| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import asyncio |
| import json |
| import logging |
| import os |
| from enum import Enum |
| from typing import Any, Optional |
| from uuid import uuid4 |
|
|
| import torch |
| from PIL import Image |
|
|
| from verl.experimental.agent_loop.agent_loop import ( |
| AgentLoopBase, |
| AgentLoopOutput, |
| register, |
| ) |
| from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser |
| from verl.experimental.agent_loop.utils import build_gpt_oss_tool_response_text |
| from verl.interactions.base import BaseInteraction |
| from verl.interactions.utils.interaction_registry import initialize_interactions_from_config |
| from verl.tools.schemas import ToolResponse |
| from verl.tools.utils.tool_registry import initialize_tools_from_config |
| from verl.utils.profiler import simple_timer |
| from verl.utils.rollout_trace import rollout_trace_op |
| from verl.workers.rollout.replica import TokenOutput |
|
|
# Module-level logger. Verbosity comes from the VERL_LOGGING_LEVEL environment variable (default "WARN").
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
class AgentState(Enum):
    """States of the tool-agent rollout state machine.

    ``ToolAgentLoop.run`` starts in PENDING and keeps dispatching to the handler for the current
    state until a handler returns TERMINATED.
    """

    # Prompt has not been tokenized yet.
    PENDING = "pending"
    # Requesting the next assistant turn from the model.
    GENERATING = "generating"
    # Executing the tool calls parsed from the last assistant turn.
    PROCESSING_TOOLS = "processing_tools"
    # Final state: the rollout loop exits.
    TERMINATED = "terminated"
    # Asking the configured interaction for the next user turn.
    INTERACTING = "interacting"
|
|
|
|
class AgentData:
    """Per-request state threaded through every state handler of the tool agent loop.

    The same instance is passed to each tool call, so a tool may inspect the full rollout history.
    Tools can keep any per-session data in ``extra_fields``.

    Args:
        messages: Chat messages for the request; tool/interaction turns are appended in place.
        image_data: Images for the request, or None when there are none. Images returned by tools
            are appended to this list.
        video_data: Videos for the request, or None when there are none
            (presumably ``(frames, metadata)`` pairs per the annotation — confirm with the caller).
        metrics: Dict that timing and preemption metrics are written into.
        request_id: Identifier used for generation requests and the interaction session.
        tools_kwargs: Per-tool kwargs, keyed by tool name.
        interaction: Interaction that supplies user turns, if one is configured.
        interaction_kwargs: Kwargs forwarded to the interaction; defaults to an empty dict.
    """

    def __init__(
        self,
        messages: list[dict[str, Any]],
        image_data: Optional[list[Image.Image]],
        video_data: Optional[list[tuple[torch.Tensor, dict[str, Any]]]],
        metrics: dict[str, Any],
        request_id: str,
        tools_kwargs: dict[str, Any],
        interaction: Optional[BaseInteraction] = None,
        interaction_kwargs: Optional[dict[str, Any]] = None,
    ):
        self.messages = messages
        self.image_data = image_data
        self.video_data = video_data
        self.metrics = metrics
        self.request_id = request_id
        self.tools_kwargs = tools_kwargs
        self.interaction = interaction
        self.interaction_kwargs = interaction_kwargs or {}

        # Token-level trajectory.
        # Full token sequence: the tokenized prompt followed by every generated or injected turn.
        self.prompt_ids: list[int] = []
        # Tokens of the most recent assistant turn only.
        self.response_ids: list[int] = []
        # One entry per token after the prompt: 1 = model-generated, 0 = tool/interaction tokens.
        self.response_mask: list[int] = []
        # Log-probs aligned with response_mask (0.0 for non-model tokens); only populated when the
        # server returns log-probs.
        self.response_logprobs: list[float] = []
        # Rewards reported by the interaction, one per user turn that returned a reward.
        self.turn_scores: list[float] = []
        # Non-None rewards reported by tool executions.
        self.tool_rewards: list[float] = []
        self.user_turns = 0
        self.assistant_turns = 0

        # Tool calls parsed from the most recent assistant turn.
        self.tool_calls: list[FunctionCall] = []

        # Latest routed-experts value reported by the server (presumably MoE routing info); None if
        # never reported.
        self.routed_experts = None

        # Free-form per-request data (server extra fields, tool session data); returned with the output.
        self.extra_fields: dict[str, Any] = {}
|
|
|
|
@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
    """Agent loop that alternates model generation with tool calls and optional user interactions.

    The rollout is a state machine over ``AgentState``:

    * PENDING tokenizes the prompt.
    * GENERATING samples an assistant turn.
    * PROCESSING_TOOLS executes the parsed tool calls and appends their responses.
    * INTERACTING asks a configured ``BaseInteraction`` for the next user turn.

    Tokens sampled from the model are marked 1 in ``response_mask``. Tokens injected from tools or
    interactions are marked 0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Multi-turn limits and tool configuration.
        multi_turn = self.rollout_config.multi_turn
        self.max_user_turns = multi_turn.max_user_turns
        self.max_assistant_turns = multi_turn.max_assistant_turns
        self.max_parallel_calls = multi_turn.max_parallel_calls
        self.max_tool_response_length = multi_turn.max_tool_response_length
        self.tool_response_truncate_side = multi_turn.tool_response_truncate_side
        tool_config_path = multi_turn.tool_config_path
        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
        self.tools = {tool.name: tool for tool in tool_list}
        self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
        self.tool_parser = ToolParser.get_tool_parser(multi_turn.format, self.tokenizer)
        self.tool_parser_name = multi_turn.format

        self.prompt_length = self.rollout_config.prompt_length
        self.response_length = self.rollout_config.response_length

        # Interactions are optional. Always define the map so the attribute exists even without a config.
        self.interaction_config_file = multi_turn.interaction_config_path
        self.interaction_map: dict[str, BaseInteraction] = {}
        if self.interaction_config_file:
            self.interaction_map = self._initialize_interactions(self.interaction_config_file)

    @rollout_trace_op
    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Roll out one multi-turn trajectory and package it as an ``AgentLoopOutput``.

        Args:
            sampling_params: Forwarded unchanged to every generation request.
            **kwargs: Must contain ``raw_prompt`` (the chat messages). May contain ``tools_kwargs``
                (per-tool kwargs keyed by tool name). When an interaction config is set, must contain
                ``extra_info["interaction_kwargs"]`` with a ``name`` key.

        Returns:
            The prompt/response split of the token trajectory, with the response, mask and log-probs
            clipped to ``response_length``.

        Raises:
            ValueError: If ``interaction_kwargs`` lacks ``name`` or names an unknown interaction.
        """
        messages = list(kwargs["raw_prompt"])

        # Collect images/videos from the prompt messages via the base-class helper.
        multi_modal_data = await self.process_vision_info(messages)
        images = multi_modal_data.get("images")
        videos = multi_modal_data.get("videos")

        metrics = {}
        request_id = uuid4().hex
        tools_kwargs = kwargs.get("tools_kwargs", {})

        interaction = None
        interaction_kwargs = {}
        if self.interaction_config_file:
            interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"]
            if "name" not in interaction_kwargs:
                raise ValueError("'name' key is required in interaction_kwargs")
            interaction_name = interaction_kwargs["name"]
            if interaction_name not in self.interaction_map:
                raise ValueError(
                    f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: "
                    f"{list(self.interaction_map.keys())}"
                )
            interaction = self.interaction_map[interaction_name]
            await interaction.start_interaction(request_id, **interaction_kwargs)

        agent_data = AgentData(
            messages=messages,
            image_data=images,
            video_data=videos,
            metrics=metrics,
            request_id=request_id,
            tools_kwargs=tools_kwargs,
            interaction=interaction,
            interaction_kwargs=interaction_kwargs,
        )

        # Drive the state machine until a handler reports TERMINATED.
        state = AgentState.PENDING
        while state != AgentState.TERMINATED:
            if state == AgentState.PENDING:
                state = await self._handle_pending_state(agent_data, sampling_params)
            elif state == AgentState.GENERATING:
                state = await self._handle_generating_state(agent_data, sampling_params)
            elif state == AgentState.PROCESSING_TOOLS:
                state = await self._handle_processing_tools_state(agent_data)
            elif state == AgentState.INTERACTING:
                state = await self._handle_interacting_state(agent_data)
            else:
                logger.error(f"Invalid state: {state}")
                state = AgentState.TERMINATED

        # prompt_ids holds the prompt followed by every later turn; the trailing len(response_mask)
        # tokens are the response. Split on an explicit index because prompt_ids[-0:] would return the
        # whole sequence when no response tokens were produced.
        split = len(agent_data.prompt_ids) - len(agent_data.response_mask)
        prompt_ids = agent_data.prompt_ids[:split]
        response_ids = agent_data.prompt_ids[split:]

        multi_modal_data = {}
        if agent_data.image_data is not None:
            multi_modal_data["images"] = agent_data.image_data
        if agent_data.video_data is not None:
            multi_modal_data["videos"] = agent_data.video_data

        output: AgentLoopOutput = AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids[: self.response_length],
            response_mask=agent_data.response_mask[: self.response_length],
            multi_modal_data=multi_modal_data,
            response_logprobs=agent_data.response_logprobs[: self.response_length]
            if agent_data.response_logprobs
            else None,
            num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
            metrics=agent_data.metrics,
            routed_experts=agent_data.routed_experts,
            extra_fields=agent_data.extra_fields,
        )
        output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
        return output

    async def _handle_pending_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
        """Tokenize the initial messages (with tool schemas and media) into ``prompt_ids``.

        ``sampling_params`` is accepted for a uniform handler signature but is not used here.
        """
        prompt_ids = await self.apply_chat_template(
            agent_data.messages,
            tools=self.tool_schemas,
            images=agent_data.image_data,
            videos=agent_data.video_data,
        )
        agent_data.prompt_ids = prompt_ids
        return AgentState.GENERATING

    async def _handle_generating_state(
        self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False
    ) -> AgentState:
        """Generate one assistant turn and decide what happens next.

        Args:
            agent_data: Rollout state. Generated tokens are appended to ``prompt_ids`` and marked 1 in
                ``response_mask``.
            sampling_params: Forwarded to ``server_manager.generate``.
            ignore_termination: If True, skip the ``response_length`` budget check; turn limits still apply.

        Returns:
            TERMINATED when a length/turn limit is hit or there is nothing more to do;
            PROCESSING_TOOLS when tool calls were parsed; INTERACTING when an interaction is configured.
        """
        add_messages: list[dict[str, Any]] = []

        with simple_timer("generate_sequences", agent_data.metrics):
            output: TokenOutput = await self.server_manager.generate(
                request_id=agent_data.request_id,
                prompt_ids=agent_data.prompt_ids,
                sampling_params=sampling_params,
                image_data=agent_data.image_data,
                video_data=agent_data.video_data,
            )

        # Accumulate preemptions across turns. The first turn records -1 when the server does not report it.
        if agent_data.metrics.get("num_preempted") is None:
            agent_data.metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1
        else:
            agent_data.metrics["num_preempted"] += output.num_preempted if output.num_preempted is not None else 0

        # While extra_fields is still empty, take all server extra fields. Afterwards, only refresh
        # max_global_steps.
        if not agent_data.extra_fields:
            agent_data.extra_fields.update(output.extra_fields)
        else:
            max_global_steps = output.extra_fields.get("max_global_steps", None)
            if max_global_steps:
                agent_data.extra_fields["max_global_steps"] = max_global_steps

        agent_data.assistant_turns += 1
        agent_data.response_ids = output.token_ids
        agent_data.prompt_ids += agent_data.response_ids
        agent_data.response_mask += [1] * len(agent_data.response_ids)
        if output.log_probs:
            agent_data.response_logprobs += output.log_probs

        if output.routed_experts is not None:
            agent_data.routed_experts = output.routed_experts

        # Length and turn limits.
        if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
            return AgentState.TERMINATED
        if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns:
            return AgentState.TERMINATED
        if self.max_user_turns and agent_data.user_turns >= self.max_user_turns:
            return AgentState.TERMINATED

        # Parse tool calls out of the freshly generated tokens.
        tools = [tool.tool_schema for tool in self.tools.values()]
        _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids, tools)

        # The interaction consumes the message history, so record the decoded assistant turn there.
        if self.interaction_config_file:
            assistant_message = await self.loop.run_in_executor(
                None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True)
            )
            add_messages.append({"role": "assistant", "content": assistant_message})
            agent_data.messages.extend(add_messages)

        if agent_data.tool_calls:
            return AgentState.PROCESSING_TOOLS
        elif self.interaction_config_file:
            return AgentState.INTERACTING
        else:
            return AgentState.TERMINATED

    @staticmethod
    def _collect_tool_images(image: Any) -> list[Any]:
        """Normalize a tool's ``image`` field (single image, list, or empty) into a list of non-None images."""
        if not image:
            return []
        if isinstance(image, list):
            return [img for img in image if img is not None]
        return [image]

    async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:
        """Execute the pending tool calls and append their responses to the context.

        At most ``max_parallel_calls`` of the parsed calls are executed, concurrently via
        ``asyncio.gather``. Each response becomes one ``tool`` message. Images returned by tools are
        appended to ``agent_data.image_data``.

        Returns:
            GENERATING so the model continues, or TERMINATED if the tool-response tokens would reach
            ``response_length``. In the TERMINATED case the tokens and images are not appended.

        Raises:
            ValueError: If a tool returns image/video data but no multimodal processor is configured.
            NotImplementedError: If a tool returns video data.
        """
        add_messages: list[dict[str, Any]] = []
        new_images_this_turn: list[Any] = []

        tasks = []
        tool_call_names = []
        for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:
            tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs, agent_data))
            tool_call_names.append(tool_call.name)

        with simple_timer("tool_calls", agent_data.metrics):
            responses = await asyncio.gather(*tasks)

        for tool_response, tool_reward, _ in responses:
            tool_images = self._collect_tool_images(tool_response.image)

            if tool_response.image or tool_response.video:
                # Structured multimodal content requires a processor that can handle media.
                if not getattr(self.processor, "image_processor", None):
                    raise ValueError(
                        "Multimedia data can only be processed by `processor`, but the processor is None. "
                        "This error is often caused if you are using a LLM model but your tool returns multimodal "
                        "data. Please use a vlm as the base model."
                    )
                # One placeholder per image, so the chat template emits as many image slots as the
                # number of images passed to apply_chat_template below.
                content: list[dict[str, Any]] = [{"type": "image"} for _ in tool_images]
                if tool_response.video:
                    content.append({"type": "video"})
                if tool_response.text:
                    content.append({"type": "text", "text": tool_response.text})
                message = {"role": "tool", "content": content}
            else:
                message = {"role": "tool", "content": tool_response.text or ""}

            add_messages.append(message)
            new_images_this_turn.extend(tool_images)

            if tool_response.video:
                logger.warning("Multimedia type 'video' is not currently supported. Only 'image' is supported.")
                raise NotImplementedError(
                    "Multimedia type 'video' is not currently supported. Only 'image' is supported."
                )

            if tool_reward is not None:
                agent_data.tool_rewards.append(tool_reward)

        agent_data.messages.extend(add_messages)

        if self.tool_parser_name == "gpt-oss":
            logger.info("manually format tool responses for gpt-oss")
            tool_response_text = build_gpt_oss_tool_response_text(add_messages, tool_call_names)
            response_ids = await self.loop.run_in_executor(
                None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
            )
        else:
            # Only this turn's tool messages are templated (system prompt stripped) and appended.
            response_ids = await self.apply_chat_template(
                add_messages,
                images=new_images_this_turn or None,
                videos=None,
                remove_system_prompt=True,
            )

        if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
            return AgentState.TERMINATED

        if new_images_this_turn:
            if agent_data.image_data is None:
                agent_data.image_data = []
            elif not isinstance(agent_data.image_data, list):
                agent_data.image_data = [agent_data.image_data]
            agent_data.image_data.extend(new_images_this_turn)

        agent_data.prompt_ids += response_ids
        agent_data.response_mask += [0] * len(response_ids)
        if agent_data.response_logprobs:
            agent_data.response_logprobs += [0.0] * len(response_ids)
        agent_data.user_turns += 1
        return AgentState.GENERATING

    async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
        """Ask the interaction for the next user turn and append it to the context.

        Returns:
            TERMINATED if the interaction asks to stop, or if the user turn would reach
            ``response_length`` (in which case its tokens are not appended); otherwise GENERATING.
        """
        should_terminate_sequence, interaction_responses, reward, _ = await agent_data.interaction.generate_response(
            agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs
        )
        agent_data.user_turns += 1

        add_messages: list[dict[str, Any]] = [{"role": "user", "content": interaction_responses}]
        agent_data.messages.extend(add_messages)

        if reward is not None:
            agent_data.turn_scores.append(reward)

        # Template only the new user turn (system prompt stripped) and append it as non-model tokens.
        response_ids = await self.apply_chat_template(
            add_messages,
            remove_system_prompt=True,
        )

        # Same budget guard as the tool path. Without it, the next generation would only produce tokens
        # that get clipped away by response_length.
        if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
            return AgentState.TERMINATED

        agent_data.prompt_ids += response_ids
        agent_data.response_mask += [0] * len(response_ids)
        if agent_data.response_logprobs:
            agent_data.response_logprobs += [0.0] * len(response_ids)

        if should_terminate_sequence:
            return AgentState.TERMINATED
        return AgentState.GENERATING

    @staticmethod
    def _truncate_tool_text(text: Optional[str], max_length: int, side: str) -> Optional[str]:
        """Clip ``text`` to ``max_length`` characters and add a truncation marker.

        ``side == "left"`` keeps the head, ``side == "right"`` keeps the tail, and any other value keeps
        both ends. Empty/None text, and text within the limit, are returned unchanged. Tails are sliced
        with an explicit start index because ``s[-0:]`` is the whole string, which would skip truncation
        when the kept length is 0.
        """
        if not text or len(text) <= max_length:
            return text
        if side == "left":
            return text[:max_length] + "...(truncated)"
        if side == "right":
            return "(truncated)..." + text[len(text) - max_length :]
        half = max_length // 2
        return text[:half] + "...(truncated)..." + text[len(text) - half :]

    async def _call_tool(
        self, tool_call: FunctionCall, tools_kwargs: dict[str, Any], agent_data: AgentData
    ) -> tuple[ToolResponse, float, dict]:
        """Create a tool instance, execute one call on it, and release it.

        Args:
            tool_call: Parsed call; ``arguments`` is a JSON string.
            tools_kwargs: Per-tool kwargs keyed by tool name. Their ``create_kwargs`` entry is forwarded
                to ``tool.create``.
            agent_data: Forwarded to ``tool.execute`` so the tool can inspect rollout state.

        Returns:
            ``(response, reward, extra)``. Any exception while parsing the arguments, or while looking
            up, creating or executing the tool, is logged. It is returned as an error-text response with
            reward 0.0 and ``{}``. Response text longer than ``max_tool_response_length`` is truncated
            according to ``tool_response_truncate_side``.
        """
        tool, instance_id = None, None
        try:
            tool_name = tool_call.name
            tool_args = json.loads(tool_call.arguments)
            tool = self.tools[tool_name]
            kwargs = tools_kwargs.get(tool_name, {})
            instance_id, _ = await tool.create(create_kwargs=kwargs.get("create_kwargs", {}))
            tool_execution_response, tool_reward, res = await tool.execute(
                instance_id, tool_args, agent_data=agent_data
            )
        except Exception as e:
            logger.warning(f"Error when executing tool: {e}")
            return ToolResponse(text=f"Error when executing tool: {e}"), 0.0, {}
        finally:
            # Release the instance whether or not execution succeeded.
            if tool and instance_id:
                await tool.release(instance_id)

        tool_response_kwargs = {
            "text": self._truncate_tool_text(
                tool_execution_response.text, self.max_tool_response_length, self.tool_response_truncate_side
            )
        }
        # Propagate multimodal payloads only when the tool actually returned them.
        for attr_name in ("image", "video"):
            attr_value = getattr(tool_execution_response, attr_name, None)
            if attr_value is not None:
                tool_response_kwargs[attr_name] = attr_value

        return ToolResponse(**tool_response_kwargs), tool_reward, res

    def _initialize_interactions(self, interaction_config_file):
        """Build the interaction registry from a config file.

        Returns:
            dict[str, BaseInteraction]: Interaction instances keyed by name. Empty when
            ``interaction_config_file`` is None.
        """
        if interaction_config_file is None:
            return {}
        return initialize_interactions_from_config(interaction_config_file)
|
|