File size: 14,456 Bytes
77320e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
import re
import time
from typing import Union, List, Dict
from werkzeug.datastructures import FileStorage
from .. import BaseAgent
from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException
from ...schemas import (
AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse,
BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile
)
from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool
from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \
OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \
TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
logger = get_logger()
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
STOP_WORD = ['Observation:']
logger = get_logger()
class AsyncReactAgent(BaseAgent):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._name = self._name or "AsyncReactAgent"
self._type = AgentType.react
self.__intermediate_steps: List[BaseAgentResponse] = []
@property
def intermediate_steps(self):
return self.__intermediate_steps
def run(self, *args, **kwargs):
pass
async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
if not isinstance(sandbox_plugin, (AsyncPythonSandBoxTool, AsyncPythonSandBoxTool)):
raise InternalErrorException("SandBox client is not ready for agent, please check init logic.")
return await sandbox_plugin.sync_to_sandbox(file)
async def async_run(self, agent_req: AgentRequest):
instruction = '\n'.join(message.content for message in agent_req.messages)
async for response in self._chat(instruction, is_cn=agent_req.is_cn):
yield response
async def _chat(self, instruction: str, is_cn=False, max_iterations=10,
max_single_step_iterations=3):
current_iteration = 0
for _ in range(max_iterations):
current_iteration += 1
llm_response = await self._single_round_thought(instruction,
max_llm_iteration=max_single_step_iterations,
is_cn=is_cn)
logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
.format(current_iteration, max_iterations, llm_response.raw_output,
llm_response.formatted_output))
yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)
if isinstance(llm_response, AgentFinish):
logger.info("Find final answer, stop iteration.")
break
self.intermediate_steps.append(llm_response)
action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration,
max_iterations, is_cn)
logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
.format(current_iteration, max_iterations, action_response.raw_output,
action_response.formatted_output))
self.intermediate_steps.append(action_response)
yield self.create_agent_response(action_response.formatted_output,
cur_output_files,
action_response.raw_output)
logger.info(f"Finished iteration in {current_iteration}.")
# TODO update logic to not be sandbox specific, sandbox related logic should be handled in sandbox client
async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
try:
response.tool = 'python_code_sandbox'
action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
logger.info(
f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
f"{action_response.output_text}")
if "STDERR" in action_response.output_text:
formatted_output = self._process_sandbox_output(action_response.output_text)
else:
formatted_output = action_response.output_text
formatted_output = replace_latex_format(formatted_output)
observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
formatted_output = f"{observation_prefix}\n{formatted_output}\n"
action_observation = AgentObservation(tool=response.tool,
formatted_output=formatted_output,
raw_output=action_response.output_text)
cur_output_files = self._get_output_files(action_response)
return action_observation, cur_output_files
except Exception as e:
logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
f"Error: {str(e)}", exc_info=True)
# TODO: We hard code here as we only have one tool
raise SandboxException("Error occurred while running the tool") from e
def _compose_prompt(self, instruction) -> str:
"""
Compose the prompt from template, worker description, examples and instruction.
"""
agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
tool_description = self._get_plugin_description()
tool_names = ", ".join(list(self.plugins_map.keys()))
if self.prompt_template is None:
raise InternalErrorException("Agent prompt is none, please check init process")
return self.prompt_template.format(
instruction=instruction,
agent_scratchpad=agent_scratchpad,
tool_description=tool_description,
tool_names=tool_names
)
async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \
Union[AgentAction, AgentFinish]:
llm_iteration_count = 0
llm_response = None
while llm_iteration_count <= max_llm_iteration:
llm_iteration_count += 1
try:
llm_response = await self._get_llm_response(instruction)
action_response = self._parse_output(llm_response.content, is_cn)
return action_response
except Exception as e:
logger.error("LLM iteration {} out of {} failed. Error: {}".
format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True)
if llm_iteration_count > max_llm_iteration:
logger.error("LLM iteration {} exceed max retry {}. Aborting".
format(llm_iteration_count, max_llm_iteration))
return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
raw_output=str(llm_response))
async def _get_llm_response(self, instruction: str):
prompt = self._compose_prompt(instruction)
logger.info("Send prompt to LLM:\n{}".format(prompt))
response = await self.llm.async_completion(prompt)
if response.state == "error":
raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))
logger.info("Got response from llm, raw response content: \n{}".format(response.content))
return response
def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
for stop_word in STOP_WORD:
if stop_word in llm_output:
llm_output = llm_output.split(stop_word)[0].rstrip()
break
# Check for Final Answer, if it is final, then just return
for indicator in FINAL_ANSWER_INDICATORS:
if indicator in llm_output:
# got final answer and remove the indicator
parts = llm_output.split(indicator)
# formatted_output = ''.join(parts[:-1]).strip()
formatted_output = ''.join(parts).strip()
formatted_output = replace_latex_format(formatted_output)
return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)
# Updated regex pattern for capturing the expected input format
ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or re.search(ACTION_REGEX_2, llm_output, re.DOTALL)
# Find action, context, and action input, build action response
if action_match:
context = action_match.group(1).strip()
action_tool_description = action_match.group(2).strip()
action_input = action_match.group(3).strip()
# Format code
# TODO: currently we only have one plugin which is sandbox, update to support multiple tools
format_code_block = self._format_code_block(action_input)
prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
formatted_output = replace_latex_format(formatted_output)
return AgentAction(tool=action_tool_description,
tool_input=format_code_block,
formatted_output=formatted_output,
raw_output=llm_output)
# Not final answer and not action, raise exception
if not re.search(r"Action\s*:", llm_output, re.DOTALL):
raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
else:
raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")
def _format_code_block(self, tool_input):
stripped_tool_input = tool_input.strip()
if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
'\n'
formatted_code = stripped_tool_input
elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
else:
formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'
return formatted_code.encode("utf-8").decode("utf-8")
def _process_sandbox_output(self, output: str):
"""Function to process the result containing STDERR."""
if len(output) <= 1000:
return output
logger.info("Output contains error, original message is over 1000, trim it for response. ori output: \n{}".
format(output))
rows = output.split("\n")
# Get the first 500 characters, respecting line boundaries
top_segment = []
length = 0
for sub_p in rows:
if length + len(sub_p) > 500:
break
top_segment.append(sub_p)
length += len(sub_p)
# Get the last 500 characters, respecting line boundaries
bottom_segment = []
length = 0
for sub_p in reversed(rows):
if length + len(sub_p) > 500:
break
bottom_segment.insert(0, sub_p)
length += len(sub_p)
# Combine the segments with "......" in between
timed_output = "\n".join(top_segment + ["......"] + bottom_segment)
return timed_output
def _get_output_files(self, tool_response) -> list[MediaFile]:
output_files = []
if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
raw_output = tool_response.raw_output
if raw_output.code == 0 and not raw_output.data.is_partial:
result_data = raw_output.data.result
# TODO confirm if we still need output and format
if len(result_data.new_generated_files) > 0:
output_files.extend([MediaFile(tos_path=file.download_link) for file in
result_data.new_generated_files])
if len(result_data.code_output_result) > 0:
output_files.extend(
[MediaFile(tos_path=image.content) for image in result_data.code_output_result
if image.type == 'image'])
return output_files
def _replace_csv_path(self, input_string):
# Search for the pattern and replace it
pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
replacement = "pd.read_csv('/path/to/your/dataset')"
updated_string = re.sub(pattern, replacement, input_string)
return updated_string
@staticmethod
def create_agent_response(formatted_output, output_files, raw_output):
return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)
|