| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import asyncio |
| import logging |
| import os |
| import random |
| from abc import ABC, abstractmethod |
| from typing import Any, Optional |
| from uuid import uuid4 |
|
|
| import hydra |
| import numpy as np |
| import ray |
| import torch |
| from cachetools import LRUCache |
| from omegaconf import DictConfig, OmegaConf |
| from PIL import Image |
| from pydantic import BaseModel, ConfigDict |
| from tensordict import TensorDict |
| from transformers import AutoProcessor, AutoTokenizer |
|
|
| from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config |
| from verl.experimental.agent_loop.utils import resolve_config_path |
| from verl.protocol import DataProto |
| from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup |
| from verl.utils.chat_template import apply_chat_template, initialize_system_prompt |
| from verl.utils.config import omega_conf_to_dataclass |
| from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class |
| from verl.utils.model import compute_position_id_with_mask |
| from verl.utils.ray_utils import auto_await, get_event_loop |
| from verl.utils.rollout_trace import ( |
| RolloutTraceConfig, |
| rollout_trace_attr, |
| rollout_trace_op, |
| ) |
| from verl.utils.tokenizer import normalize_token_ids |
| from verl.workers.config import HFModelConfig, RolloutConfig |
| from verl.workers.rollout.arithmetic_sampling import get_arithmetic_code |
| from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class |
|
|
# Module-level logger; its level is taken from the VERL_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


# Default capacity of the request-id -> server LRU cache used by GlobalRequestLoadBalancer.
DEFAULT_ROUTING_CACHE_SIZE = 10000
|
|
|
|
@ray.remote
class GlobalRequestLoadBalancer:
    """Router shared by all AgentLoopWorkers: sticky sessions plus least-in-flight balancing.

    A request id that is still present in the LRU routing cache is always sent to the
    server it was first assigned to; an unknown request id goes to the server that
    currently has the fewest in-flight requests.
    """

    def __init__(self, server_actor_ids: list[str], max_cache_size: int = DEFAULT_ROUTING_CACHE_SIZE):
        """Set up the per-server in-flight counters and the routing cache.

        Args:
            server_actor_ids (list[str]): ids of the servers that can be handed out.
            max_cache_size (int): capacity of the request-id -> server LRU cache.

        Raises:
            ValueError: if ``server_actor_ids`` is empty.
        """
        if not server_actor_ids:
            raise ValueError("server_actor_ids must be non-empty")

        # Every known server starts with zero in-flight requests.
        self._inflight_requests: dict[str, int] = dict.fromkeys(server_actor_ids, 0)
        self._request_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)

    def acquire_server(self, request_id: str) -> str:
        """Pick a server for ``request_id`` and count one more in-flight request on it.

        Args:
            request_id (str): sticky-session key.

        Returns:
            str: id of the chosen server.
        """
        try:
            # Cache hit: keep routing this request id to the same server.
            chosen = self._request_id_to_server[request_id]
        except KeyError:
            # Cache miss: take the least-loaded server and remember the assignment.
            chosen = min(self._inflight_requests, key=self._inflight_requests.get)
            self._request_id_to_server[request_id] = chosen

        self._inflight_requests[chosen] += 1
        return chosen

    def release_server(self, server_id: str) -> None:
        """Count one in-flight request less on ``server_id``.

        Args:
            server_id (str): id previously returned by ``acquire_server``.

        Raises:
            ValueError: if the server id is unknown or its in-flight counter is already at zero or below.
        """
        inflight = self._inflight_requests.get(server_id)
        if inflight is None:
            raise ValueError(f"Invalid server_id for release: {server_id}")
        if inflight <= 0:
            raise ValueError(f"Release called with no inflight requests on server {server_id}")
        self._inflight_requests[server_id] = inflight - 1
|
|
|
|
def _get_rollout_and_model_config(config: DictConfig) -> tuple[DictConfig, DictConfig]:
    """Return the ``(rollout, model)`` sub-configs for either supported config layout.

    When ``config`` has a truthy ``actor_rollout_ref`` entry the two sections are read
    from there; otherwise they are read from the top level of ``config``.

    Args:
        config (DictConfig): config holding ``rollout`` and ``model`` sections.

    Returns:
        tuple[DictConfig, DictConfig]: the rollout config and the model config.
    """
    section = config.actor_rollout_ref if config.get("actor_rollout_ref") else config
    return section.rollout, section.model
|
|
|
|
class AsyncLLMServerManager:
    """Client-side front end for a pool of LLM server actors.

    Server selection is delegated to a shared load balancer actor, which provides
    - Load balance: least in-flight requests across servers
    - Sticky session: the same ``request_id`` is routed to the same server
    """

    def __init__(
        self,
        config: DictConfig,
        servers: list[tuple[str, ray.actor.ActorHandle]],
        load_balancer_handle: ray.actor.ActorHandle,
    ):
        """Store the config, the load balancer handle and the server lookup table.

        Args:
            config (DictConfig): whole config for main entrypoint.
            servers (list[tuple[str, ray.actor.ActorHandle]]): (address, handle) pairs for each LLM server.
            load_balancer_handle (ray.actor.ActorHandle): shared global load balancer actor.
        """
        self._server_id_to_handle: dict[str, ray.actor.ActorHandle] = dict(servers)
        self._load_balancer = load_balancer_handle
        self.config = config

    async def _acquire_server(self, request_id: str) -> tuple[str, ray.actor.ActorHandle]:
        """Ask the load balancer for a server and resolve its id to a local actor handle.

        Raises:
            RuntimeError: if the load balancer returns an id this manager has no handle for.
        """
        chosen_id = await self._load_balancer.acquire_server.remote(request_id=request_id)
        chosen_handle = self._server_id_to_handle.get(chosen_id)
        if chosen_handle is None:
            raise RuntimeError(f"Unknown server_id returned by load balancer: {chosen_id}")
        return chosen_id, chosen_handle

    def _release_server(self, server_id: str) -> None:
        """Tell the load balancer that one request on ``server_id`` has finished.

        The remote call is dispatched without awaiting its result.
        """
        self._load_balancer.release_server.remote(server_id=server_id)

    @rollout_trace_op
    async def generate(
        self,
        request_id,
        *,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        image_data: Optional[list[Any]] = None,
        video_data: Optional[list[Any]] = None,
    ) -> TokenOutput:
        """Generate tokens from prompt ids on a server chosen by the load balancer.

        Args:
            request_id (str): request id for sticky session; only used for server selection.
            prompt_ids (List[int]): List of prompt token ids.
            sampling_params (Dict[str, Any]): Sampling parameters forwarded to the server.
            image_data (Optional[List[Any]]): image inputs forwarded to the server unchanged.
            video_data (Optional[List[Any]]): video inputs forwarded to the server unchanged.

        Returns:
            TokenOutput: token output
        """
        server_id, server = await self._acquire_server(request_id)
        try:
            # The server itself receives a fresh random id on every call; the caller's
            # request_id is used for routing only.
            return await server.generate.remote(
                request_id=uuid4().hex,
                prompt_ids=prompt_ids,
                sampling_params=sampling_params,
                image_data=image_data,
                video_data=video_data,
            )
        finally:
            # Always give the slot back, also when generation raised.
            self._release_server(server_id)
|
|
|
|
class AgentLoopMetrics(BaseModel):
    """Agent loop performance metrics."""

    # NOTE(review): presumably accumulated time spent in LLM generation and in tool
    # calls; the unit is not visible in this file — confirm against the agent loops.
    generate_sequences: float = 0.0
    tool_calls: float = 0.0
    # Defaults to -1. NOTE(review): looks like a "not reported" sentinel — confirm.
    num_preempted: int = -1
    # Set by AgentLoopWorker._run_two_phase_group: 1 for samples generated in the
    # arithmetic fallback phase, 0 otherwise.
    arithmetic_sampling_used: int = 0
    # Batch-level statistics written by
    # AgentLoopWorker._generate_sequences_with_two_phase_arithmetic: the mean number of
    # arithmetic samples per group and the fraction of groups with at least one.
    arithmetic_sampling_avg_samples_per_group: float = 0.0
    arithmetic_sampling_groups_triggered_fraction: float = 0.0
|
|
|
|
class AgentLoopOutput(BaseModel):
    """Agent loop output.

    Un-padded, list-based result returned by ``AgentLoopBase.run``; AgentLoopWorker
    pads it into an ``_InternalAgentLoopOutput`` during post-processing.
    """

    prompt_ids: list[int]
    """Prompt token ids."""
    response_ids: list[int]
    """Response token ids including LLM generated token, tool response token."""
    response_mask: list[int]
    """Response mask, 1 for LLM generated token, 0 for tool response token."""
    response_logprobs: Optional[list[float]] = None
    """Log probabilities for the response tokens."""
    # Post-processing accepts a numpy array or a torch tensor here and unpacks its
    # shape into three dimensions (length, layer_num, topk_num).
    routed_experts: Optional[Any] = None
    """Routed experts for the total tokens."""
    multi_modal_data: Optional[dict[str, Any]] = None
    """Multi-modal data for multi-modal tools."""
    reward_score: Optional[float] = None
    """Reward score for the trajectory."""
    num_turns: int = 0
    """Number of chat turns, including user, assistant, tool."""
    metrics: AgentLoopMetrics
    """Auxiliary performance metrics"""
    # AgentLoopWorker stores "raw_prompt" here and, for two-phase arithmetic sampling,
    # "arithmetic_sampling_used".
    extra_fields: dict[str, Any] = {}
    """Extra fields for dynamic addition."""
|
|
|
|
class _InternalAgentLoopOutput(AgentLoopOutput):
    """Internal agent loop output with padded sequences.

    Overrides the list-typed fields of ``AgentLoopOutput`` with tensor-typed ones and
    adds the tensors derived during post-processing (input ids, masks, position ids).
    """

    # torch.Tensor is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    prompt_ids: torch.Tensor
    """Padded prompt token ids."""
    response_ids: torch.Tensor
    """Padded response token ids."""
    input_ids: torch.Tensor
    """Padded input ids(prompt_ids + response_ids)."""
    position_ids: torch.Tensor
    """Padded position ids."""
    response_mask: torch.Tensor
    """Padded response mask."""
    attention_mask: torch.Tensor
    """Padded attention mask."""
    response_logprobs: Optional[torch.Tensor] = None
    """Padded log probabilities for the response tokens."""
    routed_experts: Optional[torch.Tensor] = None
    """Padded routed experts for the total tokens."""
    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None
    """Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw)."""
    extra_fields: dict[str, Any] = {}
    """Extra fields for dynamic addition."""
|
|
|
|
class DictConfigWrap:
    """Wrapper for DictConfig to avoid hydra.utils.instantiate recursive resolve.

    The wrapped config is available unchanged as the ``config`` attribute.

    Args:
        config (DictConfig): config object to carry through ``hydra.utils.instantiate``.
    """

    def __init__(self, config: DictConfig):
        self.config = config
|
|
|
|
class AgentLoopBase(ABC):
    """Base class for agent loops that chat with an LLM server and interact with environments.

    Args:
        trainer_config (DictConfigWrap): wrapped whole config for main entrypoint.
        server_manager (AsyncLLMServerManager): LLM server manager used for generation.
        tokenizer (AutoTokenizer): Tokenizer used to tokenize messages.
        processor (AutoProcessor): Processor used for multi-modal messages; may be None.
        dataset_cls (type[RLHFDataset]): Dataset class providing ``process_vision_info``.
        data_config (DictConfigWrap): wrapped dataset config.
    """

    def __init__(
        self,
        trainer_config: DictConfigWrap,
        server_manager: AsyncLLMServerManager,
        tokenizer: AutoTokenizer,
        processor: AutoProcessor,
        dataset_cls: type[RLHFDataset],
        data_config: DictConfigWrap,
        **kwargs,
    ):
        # Unwrap the configs that were wrapped for hydra instantiation.
        self.config = trainer_config.config
        self.data_config = data_config.config
        self.rollout_config = _get_rollout_and_model_config(self.config)[0]

        self.server_manager = server_manager
        self.tokenizer = tokenizer
        self.processor = processor
        self.dataset_cls = dataset_cls

        self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {})
        self.system_prompt = initialize_system_prompt(self.tokenizer, **self.apply_chat_template_kwargs)
        self.loop = get_event_loop()

    async def process_vision_info(self, messages: list[dict]) -> dict:
        """Extract images and videos from messages.

        Args:
            messages (list[dict]): Input messages.

        Returns:
            dict: Multi-modal data; contains "images" and/or "videos" only when the
            dataset class returned a non-None value for them. Empty without a processor.
        """
        if self.processor is None:
            return {}

        images, videos = await self.dataset_cls.process_vision_info(
            messages, image_patch_size=self.processor.image_processor.patch_size, config=self.data_config
        )
        extracted = {"images": images, "videos": videos}
        return {key: value for key, value in extracted.items() if value is not None}

    async def apply_chat_template(
        self,
        messages: list[dict],
        tools: list[dict] = None,
        images: list[Image.Image] = None,
        videos: list[tuple[torch.Tensor, dict]] = None,
        remove_system_prompt: bool = False,
    ):
        """Render messages through the chat template and return the prompt token ids.

        Args:
            messages (list[dict]): Input messages.
            tools (list[dict], optional): Tools schemas. Defaults to None.
            images (list[Image.Image], optional): Input images. Defaults to None.
            videos (list[tuple[torch.Tensor, dict]], optional): Input videos as
                (video, metadata) pairs. Defaults to None.
            remove_system_prompt (bool, optional): Whether to drop the first
                ``len(self.system_prompt)`` token ids. Defaults to False.

        Returns:
            list[int]: Prompt token ids.
        """

        def _render(template_owner, tokenize):
            # Runs in the executor so template rendering does not block the event loop.
            return apply_chat_template(
                template_owner,
                messages,
                tools=tools,
                add_generation_prompt=True,
                tokenize=tokenize,
                **self.apply_chat_template_kwargs,
            )

        if self.processor is None:
            # Text-only path: the tokenizer renders and tokenizes in one step.
            token_ids = await self.loop.run_in_executor(None, _render, self.tokenizer, True)
            prompt_ids = normalize_token_ids(token_ids)
        else:
            # Multi-modal path: render text first, then let the processor tokenize it
            # together with the images and videos.
            prompt_text = await self.loop.run_in_executor(None, _render, self.processor, False)

            if videos is None:
                video_frames = None
                video_metadatas = None
            else:
                # Split the (video, metadata) pairs into two parallel lists.
                video_frames, video_metadatas = zip(*videos, strict=False)
                video_frames, video_metadatas = list(video_frames), list(video_metadatas)

            model_inputs = self.processor(
                text=[prompt_text],
                images=images,
                videos=video_frames,
                video_metadata=video_metadatas,
                return_tensors="pt",
                do_sample_frames=False,
            )
            prompt_ids = normalize_token_ids(model_inputs.pop("input_ids"))

        return prompt_ids[len(self.system_prompt) :] if remove_system_prompt else prompt_ids

    @abstractmethod
    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Run agent loop to interact with LLM server and environment.

        Args:
            sampling_params (Dict[str, Any]): LLM sampling params.
            **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.

        Returns:
            AgentLoopOutput: Agent loop output.
        """
        raise NotImplementedError
|
|
|
|
| """Agent loop registry: key is agent_name, value is a dict of agent loop config |
| used by hydra.utils.instantiate to initialize agent loop instance. |
| |
| https://hydra.cc/docs/advanced/instantiate_objects/overview/ |
| """ |
# Filled by the ``register`` decorator ({"_target_": <class path>}) and by
# AgentLoopWorker.__init__ with entries loaded from ``agent_loop_config_path``.
_agent_loop_registry: dict[str, dict] = {}
|
|
|
|
def register(agent_name: str):
    """Return a class decorator that registers an agent loop class under ``agent_name``.

    The registry entry is ``{"_target_": "<module>.<qualname>"}``; the decorated class
    itself is returned unchanged.
    """

    def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
        target_path = f"{subclass.__module__}.{subclass.__qualname__}"
        _agent_loop_registry[agent_name] = dict(_target_=target_path)
        return subclass

    return decorator
|
|
|
|
| class AgentLoopWorker: |
| """Agent loop worker takes a batch of messages and run each message in an agent loop. |
| |
| Args: |
| config (DictConfig): whole config for main entrypoint. |
| servers (list[tuple[str, ray.actor.ActorHandle]]): (address, handle) pairs for each LLM server. |
| reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation. |
| """ |
|
|
def __init__(
    self,
    config: DictConfig,
    servers: list[tuple[str, ray.actor.ActorHandle]],
    load_balancer_handle: ray.actor.ActorHandle,
    reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
    """Set up configs, server manager, tokenizer/processor, agent loop registry and tracing.

    Args:
        config (DictConfig): YAML config.
        servers (list[tuple[str, ray.actor.ActorHandle]]): (address, handle) pairs for each LLM server.
        load_balancer_handle (ray.actor.ActorHandle): shared global load balancer actor.
        reward_loop_worker_handles (list[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
    """
    self.config = config
    raw_rollout_config, raw_model_config = _get_rollout_and_model_config(config)
    self.rollout_config: RolloutConfig = omega_conf_to_dataclass(raw_rollout_config)
    self.model_config: HFModelConfig = omega_conf_to_dataclass(raw_model_config)

    # Build a server manager only when none has been attached to the instance yet.
    if not hasattr(self, "server_manager"):
        self.server_manager = AsyncLLMServerManager(config, servers, load_balancer_handle=load_balancer_handle)

    self.dataset_cls = get_dataset_class(config.data)
    self.reward_loop_worker_handles = reward_loop_worker_handles
    # Ensures the "no reward loop workers" warning is logged at most once per worker.
    self._warned_two_phase_missing_reward_loop = False

    self.tokenizer = self.model_config.tokenizer
    self.processor = self.model_config.processor

    # Register additional agent loops declared in an external config file.
    agent_loop_config_path = self.rollout_config.agent.agent_loop_config_path
    if agent_loop_config_path:
        for agent_loop_config in OmegaConf.load(resolve_config_path(agent_loop_config_path)):
            _agent_loop_registry[agent_loop_config.name] = agent_loop_config

    # A custom chat template overrides the template of the tokenizer and, if present, the processor.
    if self.model_config.get("custom_chat_template", None) is not None:
        custom_template = self.model_config.custom_chat_template
        if self.model_config.processor is not None:
            self.model_config.processor.chat_template = custom_template
        self.model_config.tokenizer.chat_template = custom_template

    trace_config = self.rollout_config.trace
    RolloutTraceConfig.init(
        trace_config.project_name,
        trace_config.experiment_name,
        trace_config.get("backend"),
        trace_config.get("token2text", False),
        trace_config.get("max_samples_per_step_per_worker", None),
    )
|
|
async def generate_sequences(self, batch: DataProto) -> DataProto:
    """Run one agent loop per batch row and collect the results into a DataProto.

    Args:
        batch (DataProto): Input batch.

    Returns:
        DataProto: Output batch.
        - prompts: [bsz, prompt_length], prompt token ids from dataset.
        - responses: [bsz, response_length], output token ids include response tokens
          from LLM generation and observation tokens from tool_calls.
        - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.
        - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens
          and response tokens.
        - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.
        - position_ids: [bsz, prompt_length + response_length], incremental position ids.

    For multi-turn conversations:
    responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|
    response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|
    """
    rollout_cfg = self.rollout_config
    non_tensors = batch.non_tensor_batch
    is_validation = batch.meta_info.get("validate", False)

    sampling_params = {
        "temperature": rollout_cfg.temperature,
        "top_p": rollout_cfg.top_p,
        "top_k": rollout_cfg.top_k,
        "repetition_penalty": 1.0,
        "logprobs": rollout_cfg.calculate_log_probs,
    }
    if is_validation:
        # Validation uses its own temperature / top_p / top_k.
        val_kwargs = rollout_cfg.val_kwargs
        sampling_params.update(
            top_p=val_kwargs.top_p,
            top_k=val_kwargs.top_k,
            temperature=val_kwargs.temperature,
        )

    # Rows without an explicit agent fall back to the configured default agent loop.
    if "agent_name" not in non_tensors:
        non_tensors["agent_name"] = np.array([rollout_cfg.agent.default_agent_loop] * len(batch), dtype=object)

    # Key identifying the source sample of each row: "index", else "uid", else the row number.
    for key_name in ("index", "uid"):
        if key_name in non_tensors:
            sample_keys = non_tensors[key_name]
            break
    else:
        sample_keys = np.arange(len(batch))

    rollout_ns = non_tensors.get("rollout_n")

    # Trace every row unless a per-worker cap is set and there are more distinct
    # samples than the cap; in that case trace all rows of a random subset of samples.
    trace_limit = RolloutTraceConfig.get_instance().max_samples_per_step_per_worker
    traced_indices = set(range(len(batch)))
    if trace_limit is not None:
        unique_sample_keys = np.unique(sample_keys)
        if trace_limit < len(unique_sample_keys):
            selected_samples = set(np.random.choice(unique_sample_keys, trace_limit, replace=False).tolist())
            traced_indices = {i for i in range(len(batch)) if sample_keys[i] in selected_samples}

    trajectory_info = get_trajectory_info(
        batch.meta_info.get("global_steps", -1),
        sample_keys.tolist(),
        is_validation,
        rollout_ns.tolist() if rollout_ns is not None else None,
    )

    if self._should_use_two_phase_arithmetic(validate=is_validation):
        outputs = await self._generate_sequences_with_two_phase_arithmetic(
            batch=batch,
            sampling_params=sampling_params,
            trajectory_info=trajectory_info,
            sample_keys=sample_keys.tolist(),
            traced_indices=traced_indices,
        )
    else:

        def _launch(row: int):
            # One task per row, fed with that row's slice of every non-tensor field.
            row_fields = {name: column[row] for name, column in non_tensors.items()}
            row_sampling_params = self._maybe_apply_arithmetic_sampling(sampling_params, trajectory_info[row])
            return asyncio.create_task(
                self._run_agent_loop(
                    row_sampling_params,
                    trajectory_info[row],
                    trace=row in traced_indices,
                    **row_fields,
                )
            )

        outputs = await asyncio.gather(*[_launch(row) for row in range(len(batch))])

    return self._postprocess(outputs, input_non_tensor_batch=batch.non_tensor_batch)
|
|
def _should_use_two_phase_arithmetic(self, *, validate: bool) -> bool:
    """Decide whether this batch should go through two-phase arithmetic sampling.

    Args:
        validate (bool): whether the batch is a validation batch.

    Returns:
        bool: True only when arithmetic sampling is enabled, applies to this kind of
        batch, has a positive ``probe_count`` and reward loop workers are available.
        When only the reward loop workers are missing, a warning is logged once.
    """
    cfg = self.rollout_config.arithmetic_sampling
    if (not cfg.enable) or (validate and not cfg.apply_to_validation) or (cfg.probe_count <= 0):
        return False

    if self.reward_loop_worker_handles is not None:
        return True

    if not self._warned_two_phase_missing_reward_loop:
        logger.warning(
            "Two-phase arithmetic sampling requested, but no reward loop workers are available. "
            "Falling back to the original arithmetic behavior."
        )
        self._warned_two_phase_missing_reward_loop = True
    return False
|
|
async def _generate_sequences_with_two_phase_arithmetic(
    self,
    *,
    batch: DataProto,
    sampling_params: dict[str, Any],
    trajectory_info: list[dict[str, Any]],
    sample_keys: list[Any],
    traced_indices: set[int],
) -> list[_InternalAgentLoopOutput]:
    """Run every group of rows through the two-phase procedure concurrently.

    Rows are grouped with ``group_consecutive_indices(sample_keys)``; each group is
    handled by ``_run_two_phase_group``. Afterwards two batch-level statistics are
    written into the metrics of every output: the mean number of arithmetic samples
    per group and the fraction of groups that used at least one.

    Returns:
        list[_InternalAgentLoopOutput]: outputs of all groups, concatenated in group order.
    """
    pending_groups = [
        asyncio.create_task(
            self._run_two_phase_group(
                batch=batch,
                sampling_params=sampling_params,
                trajectory_info=trajectory_info,
                group_indices=group_indices,
                traced_indices=traced_indices,
            )
        )
        for group_indices in group_consecutive_indices(sample_keys)
    ]
    group_results = await asyncio.gather(*pending_groups)

    outputs = []
    per_group_counts = []
    for group_outputs, arithmetic_used in group_results:
        outputs.extend(group_outputs)
        per_group_counts.append(arithmetic_used)

    arithmetic_counts = np.array(per_group_counts, dtype=np.float32)
    if len(arithmetic_counts) > 0:
        avg_samples_per_group = float(arithmetic_counts.mean())
        groups_triggered_fraction = float((arithmetic_counts > 0).mean())
    else:
        avg_samples_per_group = 0.0
        groups_triggered_fraction = 0.0

    # The statistics are batch-level, so every output carries the same two values.
    for output in outputs:
        output.metrics.arithmetic_sampling_avg_samples_per_group = avg_samples_per_group
        output.metrics.arithmetic_sampling_groups_triggered_fraction = groups_triggered_fraction

    return outputs
|
|
async def _run_two_phase_group(
    self,
    *,
    batch: DataProto,
    sampling_params: dict[str, Any],
    trajectory_info: list[dict[str, Any]],
    group_indices: list[int],
    traced_indices: set[int],
) -> tuple[list[_InternalAgentLoopOutput], int]:
    """Generate one group in two phases: probe samples first, then the remaining samples.

    Phase 1 runs the first ``probe_count`` rows of the group without arithmetic
    sampling. If rows remain and none of the probe outputs passes
    ``pass_reward_threshold``, phase 2 runs the remaining rows with arithmetic
    sampling; otherwise phase 2 runs them without it.

    Returns:
        tuple[list[_InternalAgentLoopOutput], int]: the outputs in group order and the
        number of samples in this group that were generated with arithmetic sampling.
    """
    arithmetic_config = self.rollout_config.arithmetic_sampling
    probe_count = min(arithmetic_config.probe_count, len(group_indices))
    probe_indices = group_indices[:probe_count]
    remaining_indices = group_indices[probe_count:]

    def _spawn(batch_idx: int, use_arithmetic: bool):
        return asyncio.create_task(
            self._run_group_sample(
                batch=batch,
                sampling_params=sampling_params,
                trajectory=trajectory_info[batch_idx],
                batch_idx=batch_idx,
                trace=batch_idx in traced_indices,
                use_arithmetic=use_arithmetic,
            )
        )

    # Phase 1: probe samples, never with arithmetic sampling.
    phase1_outputs = await asyncio.gather(*[_spawn(batch_idx, False) for batch_idx in probe_indices])

    # Fall back to arithmetic sampling only if there is something left to generate
    # and every probe failed the reward threshold.
    should_fallback = probe_count < len(group_indices) and all(
        not self._reward_passed(probe.reward_score, arithmetic_config.pass_reward_threshold)
        for probe in phase1_outputs
    )

    # Phase 2: the rest of the group, all with or all without arithmetic sampling.
    phase2_outputs = []
    if remaining_indices:
        phase2_outputs = await asyncio.gather(
            *[_spawn(batch_idx, should_fallback) for batch_idx in remaining_indices]
        )
    arithmetic_used = len(remaining_indices) if should_fallback else 0

    finalized_outputs = []
    for local_idx, output in enumerate([*phase1_outputs, *phase2_outputs]):
        assert output is not None
        output.metrics.arithmetic_sampling_used = 1 if should_fallback and local_idx >= probe_count else 0
        output.extra_fields["arithmetic_sampling_used"] = bool(output.metrics.arithmetic_sampling_used)
        finalized_outputs.append(output)

    return finalized_outputs, arithmetic_used
|
|
async def _run_group_sample(
    self,
    *,
    batch: DataProto,
    sampling_params: dict[str, Any],
    trajectory: dict[str, Any],
    batch_idx: int,
    trace: bool,
    use_arithmetic: bool,
) -> _InternalAgentLoopOutput:
    """Run the agent loop for a single batch row of a two-phase group.

    Args:
        batch (DataProto): batch whose non-tensor fields are sliced at ``batch_idx``.
        sampling_params (dict[str, Any]): base sampling params.
        trajectory (dict[str, Any]): trajectory info of this row.
        batch_idx (int): row index inside ``batch``.
        trace (bool): whether this row is traced.
        use_arithmetic (bool): passed as ``force_apply`` to ``_maybe_apply_arithmetic_sampling``.

    Returns:
        _InternalAgentLoopOutput: post-processed output of the agent loop.
    """
    row_fields = {name: column[batch_idx] for name, column in batch.non_tensor_batch.items()}
    effective_sampling_params = self._maybe_apply_arithmetic_sampling(
        sampling_params, trajectory, force_apply=use_arithmetic
    )
    result = await self._run_agent_loop(effective_sampling_params, trajectory, trace=trace, **row_fields)
    return result
|
|
@staticmethod
def _reward_passed(score: Optional[float], threshold: float) -> bool:
    """Return True when ``score`` is present and strictly greater than ``threshold``."""
    if score is None:
        return False
    return float(score) > threshold
|
|
def _maybe_apply_arithmetic_sampling(
    self,
    sampling_params: dict[str, Any],
    trajectory: dict[str, Any],
    *,
    force_apply: Optional[bool] = None,
) -> dict[str, Any]:
    """Return sampling params, extended with an arithmetic code when arithmetic sampling applies.

    Args:
        sampling_params (dict[str, Any]): base sampling params; never modified in place.
        trajectory (dict[str, Any]): trajectory info; reads "validate" and "rollout_n".
        force_apply (Optional[bool]): None defers to the config (``enable`` and
            ``apply_to_validation``); True/False overrides the config decision.

    Returns:
        dict[str, Any]: ``sampling_params`` itself when arithmetic sampling is not
        applied, otherwise a copy whose ``extra_args["arithmetic_code"]`` is set.
    """
    cfg = self.rollout_config.get("arithmetic_sampling", {}) or {}
    # The config section may be a plain dict or an attribute-style object.
    if isinstance(cfg, dict):
        enabled = cfg.get("enable", False)
        apply_to_validation = cfg.get("apply_to_validation", False)
        configured_group_size = cfg.get("group_size")
        seed = cfg.get("seed", 0)
    else:
        enabled = cfg.enable
        apply_to_validation = cfg.apply_to_validation
        configured_group_size = cfg.group_size
        seed = cfg.seed

    if force_apply is None:
        skip = (not enabled) or (trajectory["validate"] and not apply_to_validation)
    else:
        skip = not force_apply
    if skip:
        return sampling_params

    # Without an explicit group size, use the number of rollouts per prompt.
    if configured_group_size is not None:
        group_size = configured_group_size
    elif trajectory["validate"]:
        group_size = self.rollout_config.val_kwargs.n
    else:
        group_size = self.rollout_config.n

    # Copy both levels so the caller's dicts stay untouched.
    patched_params = dict(sampling_params)
    patched_extra_args = dict(patched_params.get("extra_args") or {})
    patched_extra_args["arithmetic_code"] = get_arithmetic_code(
        group_size=group_size,
        seed=seed,
        rollout_n=int(trajectory["rollout_n"]),
    )
    patched_params["extra_args"] = patched_extra_args
    return patched_params
|
|
async def _run_agent_loop(
    self,
    sampling_params: dict[str, Any],
    trajectory: dict[str, Any],
    *,
    agent_name: str,
    trace: bool = True,
    **kwargs,
) -> _InternalAgentLoopOutput:
    """Instantiate the registered agent loop ``agent_name``, run it and post-process its output.

    Args:
        sampling_params (dict[str, Any]): sampling params handed to the agent loop.
        trajectory (dict[str, Any]): provides "step", "sample_index", "rollout_n" and
            "validate" for the trace attributes.
        agent_name (str): key into the agent loop registry.
        trace (bool): whether this trajectory is traced.
        **kwargs: dataset fields of the row, forwarded to ``run`` and to post-processing.

    Returns:
        _InternalAgentLoopOutput: padded output of the agent loop.
    """
    trace_attrs = {key: trajectory[key] for key in ("step", "sample_index", "rollout_n", "validate")}
    with rollout_trace_attr(**trace_attrs, name="agent_loop", trace=trace):
        assert agent_name in _agent_loop_registry, (
            f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
        )

        # A fresh agent loop instance is created for every trajectory.
        agent_loop = hydra.utils.instantiate(
            config=_agent_loop_registry[agent_name],
            trainer_config=DictConfigWrap(config=self.config),
            server_manager=self.server_manager,
            tokenizer=self.tokenizer,
            processor=self.processor,
            dataset_cls=self.dataset_cls,
            data_config=DictConfigWrap(self.config.data),
        )
        raw_output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)
        return await self._agent_loop_postprocess(raw_output, **kwargs)
|
|
async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopOutput:
    """Perform post-processing operations on the output of each individual agent loop.

    Pads the prompt (left) and the response (right) to the configured lengths, builds
    the attention mask, input ids, position ids and multi-modal inputs, pads the
    optional log-probs and routed-experts data, triggers score computation and packs
    everything into an ``_InternalAgentLoopOutput`` with a leading batch dimension of 1.

    Args:
        output: un-padded agent loop output (an ``AgentLoopOutput``).
        **kwargs: dataset fields of the row; "raw_prompt" is required.

    Returns:
        _InternalAgentLoopOutput: padded, tensor-based output.

    Raises:
        TypeError: if ``output.routed_experts`` is neither a numpy array nor a torch tensor.
        ValueError: if the routed-experts data would start before position 0.
    """
    output.extra_fields["raw_prompt"] = kwargs["raw_prompt"]

    # Layout of the padded sequence produced below:
    #   prompt:   left-padded to rollout_config.prompt_length
    #   response: right-padded to rollout_config.response_length
    #   input_ids / attention_mask: prompt followed by response
    # NOTE(review): tokenizer.pad is called with padding="max_length" and no truncation
    # argument, so sequences are presumably expected to fit already — confirm upstream.

    # Prompt: pad on the left.
    self.tokenizer.padding_side = "left"
    prompt_output = self.tokenizer.pad(
        {"input_ids": output.prompt_ids},
        padding="max_length",
        max_length=self.rollout_config.prompt_length,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Make sure there is a batch dimension.
    if prompt_output["input_ids"].dim() == 1:
        prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
        prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

    # Response: pad on the right.
    self.tokenizer.padding_side = "right"
    response_output = self.tokenizer.pad(
        {"input_ids": output.response_ids},
        padding="max_length",
        max_length=self.rollout_config.response_length,
        return_tensors="pt",
        return_attention_mask=True,
    )
    if response_output["input_ids"].dim() == 1:
        response_output["input_ids"] = response_output["input_ids"].unsqueeze(0)
        response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0)

    # The response mask is padded by passing it through tokenizer.pad as if it were
    # token ids (still right-padded).
    response_mask_output = self.tokenizer.pad(
        {"input_ids": output.response_mask},
        padding="max_length",
        max_length=self.rollout_config.response_length,
        return_tensors="pt",
        return_attention_mask=False,
    )
    if response_mask_output["input_ids"].dim() == 1:
        response_mask_output["input_ids"] = response_mask_output["input_ids"].unsqueeze(0)

    response_logprobs = None
    if output.response_logprobs is not None:
        # Right-pad the log-probs with 0.0 up to response_length.
        # NOTE(review): if there are more log-probs than response_length, pad_size is
        # negative, nothing is appended and nothing is truncated here.
        pad_size = self.rollout_config.response_length - len(output.response_logprobs)
        response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0)

    # Multiplying by the response attention mask forces padded positions of the mask
    # to 0, whatever value the tokenizer padded them with.
    response_mask = response_mask_output["input_ids"] * response_output["attention_mask"]
    attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
    input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1)

    routed_experts = None
    if output.routed_experts is not None:
        total_length = input_ids.shape[1]
        length, layer_num, topk_num = output.routed_experts.shape
        if isinstance(output.routed_experts, np.ndarray):
            routed_experts_array = output.routed_experts
            # Copy read-only arrays before handing them to torch.from_numpy.
            if not routed_experts_array.flags.writeable:
                routed_experts_array = routed_experts_array.copy()
            experts_tensor = torch.from_numpy(routed_experts_array)
        elif isinstance(output.routed_experts, torch.Tensor):
            experts_tensor = output.routed_experts
        else:
            raise TypeError(f"Unsupported type for routed_experts: {type(output.routed_experts)}")
        routed_experts = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts_tensor.dtype)

        # The prompt is left-padded, so its first real token sits at
        # (padded prompt length - real prompt length); expert data starts there.
        start_pos = prompt_output["input_ids"].shape[1] - len(output.prompt_ids)
        end_pos = min(start_pos + length, total_length)

        # NOTE(review): end_pos is already clamped by min() above, so only the
        # start_pos < 0 half of this check can trigger.
        if start_pos < 0 or end_pos > total_length:
            raise ValueError(
                f"Invalid position range: start_pos={start_pos}, end_pos={end_pos}, total_length={total_length}"
            )

        # NOTE(review): experts_tensor is not sliced; if end_pos was clamped, the
        # shapes presumably no longer match and this assignment raises — confirm.
        routed_experts[:, start_pos:end_pos] = experts_tensor.unsqueeze(0)

    multi_modal_inputs = self._compute_multi_modal_inputs(output, input_ids)
    position_ids = self._compute_position_ids(input_ids, attention_mask, multi_modal_inputs)
    # reward_score is read from `output` afterwards, so _compute_score presumably
    # stores its result on `output` — TODO confirm.
    await self._compute_score(
        output,
        prompts=prompt_output["input_ids"],
        responses=response_output["input_ids"],
        attention_mask=attention_mask,
        input_ids=input_ids,
        position_ids=position_ids,
        kwargs=kwargs,
    )

    return _InternalAgentLoopOutput(
        prompt_ids=prompt_output["input_ids"],
        response_ids=response_output["input_ids"],
        input_ids=input_ids,
        position_ids=position_ids,
        response_mask=response_mask,
        attention_mask=attention_mask,
        response_logprobs=response_logprobs,
        routed_experts=routed_experts,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_data=output.multi_modal_data,
        reward_score=output.reward_score,
        num_turns=output.num_turns,
        metrics=output.metrics,
        extra_fields=output.extra_fields,
    )
|
|
| def _compute_multi_modal_inputs(self, output, input_ids) -> dict[str, torch.Tensor]: |
| """Compute multi-modal inputs with image and video.""" |
| multi_modal_inputs = {} |
| if self.processor is None: |
| return multi_modal_inputs |
|
|
| images = output.multi_modal_data.get("images") |
| videos = output.multi_modal_data.get("videos") |
| |
| if videos is not None: |
| videos, video_metadatas = zip(*videos, strict=False) |
| videos, video_metadatas = list(videos), list(video_metadatas) |
| else: |
| video_metadatas = None |
| current_text = self.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True) |
| multi_modal_inputs = self.processor( |
| text=[current_text], |
| images=images, |
| videos=videos, |
| video_metadata=video_metadatas, |
| return_tensors="pt", |
| do_sample_frames=False, |
| ) |
| multi_modal_inputs.pop("input_ids", None) |
| multi_modal_inputs.pop("attention_mask", None) |
|
|
| |
| |
| multi_modal_inputs = dict(multi_modal_inputs.convert_to_tensors("pt")) |
| image_grid_thw = multi_modal_inputs.get("image_grid_thw") |
| if image_grid_thw is not None: |
| images_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]) |
| multi_modal_inputs["images_seqlens"] = images_seqlens |
| return multi_modal_inputs |
|
|
def _compute_position_ids(self, input_ids, attention_mask, multi_modal_inputs) -> torch.Tensor:
    """Build position ids for one sample, with extra rope rows when a processor is set.

    Without a processor the result of ``compute_position_id_with_mask`` is returned.
    With a processor, a text position row is stacked (on dim 1) in front of the rows
    produced by ``processor.get_rope_index``.

    Note:
        ``"mm_token_type_ids"`` is removed from ``multi_modal_inputs`` as a side effect.
    """
    processor = self.processor
    if processor is None:
        return compute_position_id_with_mask(attention_mask)

    rope_kwargs = {
        "image_grid_thw": multi_modal_inputs.get("image_grid_thw"),
        "video_grid_thw": multi_modal_inputs.get("video_grid_thw"),
    }

    # Only the presence of the processor-provided value matters; the ids are rebuilt from
    # input_ids (1 = image token, 2 = video token, 0 = everything else) for the first row.
    # NOTE(review): presumably so they line up with the padded sequence - confirm.
    provided_token_types = multi_modal_inputs.pop("mm_token_type_ids", None)
    if provided_token_types is not None:
        first_row = input_ids[0]
        rebuilt_token_types = torch.zeros_like(input_ids)
        rebuilt_token_types[0][first_row == processor.image_token_id] = 1
        rebuilt_token_types[0][first_row == processor.video_token_id] = 2
        rope_kwargs["mm_token_type_ids"] = rebuilt_token_types

    # NOTE(review): get_rope_index is presumably the model's method attached to the
    # processor elsewhere, returning (rows, 1, seq_len) - confirm.
    rope_ids, _ = processor.get_rope_index(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **rope_kwargs,
    )
    vision_position_ids = rope_ids.transpose(0, 1)  # swap the first two dims so batch comes first

    # Text positions: 0..k-1 on attended tokens, 1 on masked-out tokens.
    attended = attention_mask[0].bool()
    text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)
    text_position_ids[0, attended] = torch.arange(attended.sum().item())

    # (1, seq_len) -> (1, 1, seq_len), then prepend to the rope rows along dim 1.
    return torch.cat((text_position_ids.unsqueeze(0), vision_position_ids), dim=1)
|
|
async def _compute_score(self, output, prompts, responses, attention_mask, input_ids, position_ids, kwargs):
    """Fill in the reward of a single sample through a reward loop worker.

    Nothing happens when the sample already carries a ``reward_score`` or when no
    reward loop worker handles are configured. Otherwise a one-sample ``DataProto``
    is sent to a randomly chosen worker and ``output`` is updated in place with the
    returned ``reward_score`` and ``reward_extra_info``.
    """
    worker_handles = self.reward_loop_worker_handles
    if output.reward_score is not None or worker_handles is None:
        return

    tensors = TensorDict(
        {
            "prompts": prompts,
            "responses": responses,
            "attention_mask": attention_mask,
            "input_ids": input_ids,
            "position_ids": position_ids,
        },
        batch_size=1,
    )

    # Every extra keyword becomes a length-1 array; the two reserved keys are written
    # last so they win over same-named entries in kwargs.
    extras = {name: np.array([value]) for name, value in kwargs.items()}
    extras["__num_turns__"] = np.array([output.num_turns])
    extras["tool_extra_fields"] = np.array([output.extra_fields], dtype=object)

    request = DataProto(batch=tensors, non_tensor_batch=extras)
    chosen_worker = random.choice(worker_handles)
    reply = await chosen_worker.compute_score.remote(request)

    output.reward_score = reply["reward_score"]
    output.extra_fields["reward_extra_info"] = reply["reward_extra_info"]
|
|
def _postprocess(
    self,
    inputs: list[_InternalAgentLoopOutput],
    input_non_tensor_batch: dict | None = None,
) -> DataProto:
    """Process the padded outputs from _run_agent_loop and combine them into a batch.

    Args:
        inputs: Per-sample outputs; their tensors are concatenated along dim 0.
        input_non_tensor_batch: Non-tensor fields of the input batch. Merged into the
            result only when no reward loop worker handles are configured.

    Returns:
        DataProto whose ``batch`` holds the concatenated tensors (plus ``rm_scores`` when
        every sample has a reward score), whose ``non_tensor_batch`` holds per-sample
        object arrays, and whose ``meta_info`` holds the per-sample metrics (plus
        ``reward_extra_keys`` when ``rm_scores`` is present).
    """
    prompt_ids = torch.cat([sample.prompt_ids for sample in inputs], dim=0)
    response_ids = torch.cat([sample.response_ids for sample in inputs], dim=0)
    response_mask = torch.cat([sample.response_mask for sample in inputs], dim=0)
    attention_mask = torch.cat([sample.attention_mask for sample in inputs], dim=0)
    input_ids = torch.cat([sample.input_ids for sample in inputs], dim=0)
    position_ids = torch.cat([sample.position_ids for sample in inputs], dim=0)

    # Optional tensors are all-or-nothing; the first sample decides whether they are present.
    optional_outputs = {}
    if inputs[0].response_logprobs is not None:
        optional_outputs["rollout_log_probs"] = torch.cat([sample.response_logprobs for sample in inputs], dim=0)
    if inputs[0].routed_experts is not None:
        optional_outputs["routed_experts"] = torch.cat([sample.routed_experts for sample in inputs], dim=0)

    batch = TensorDict(
        {
            "prompts": prompt_ids,
            "responses": response_ids,
            "response_mask": response_mask,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            **optional_outputs,
        },
        batch_size=len(inputs),
    )

    scores = [sample.reward_score for sample in inputs]
    if all(score is not None for score in scores):
        prompt_length = prompt_ids.size(1)
        # Index of the last attended response token of each sample.
        # NOTE(review): a response with no attended token yields -1, which wraps to the
        # final position - confirm this is acceptable.
        response_length = attention_mask[:, prompt_length:].sum(dim=1) - 1
        rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
        rm_scores[torch.arange(response_mask.size(0)), response_length] = torch.tensor(scores, dtype=torch.float32)
        batch["rm_scores"] = rm_scores

    non_tensor_batch = {
        "__num_turns__": np.array([sample.num_turns for sample in inputs], dtype=np.int32),
    }
    if self.reward_loop_worker_handles is None and input_non_tensor_batch:
        non_tensor_batch.update(input_non_tensor_batch)

    # Reward extra info: the key set is taken from the first sample. Samples that lack a
    # key (or carry no info at all) contribute None instead of raising KeyError, matching
    # how heterogeneous extra_fields are handled below.
    reward_extra_infos = [sample.extra_fields.get("reward_extra_info") or {} for sample in inputs]
    reward_extra_keys = list(reward_extra_infos[0].keys())
    for key in reward_extra_keys:
        non_tensor_batch[key] = np.array([info.get(key) for info in reward_extra_infos])

    # Multi-modal inputs are kept per sample as an object array.
    multi_modal_inputs_list = [sample.multi_modal_inputs for sample in inputs]
    if any(mmi is not None for mmi in multi_modal_inputs_list):
        non_tensor_batch["multi_modal_inputs"] = np.array(multi_modal_inputs_list, dtype=object)

    metrics = [sample.metrics.model_dump() for sample in inputs]

    # Collect extra fields from all samples. The default keys are always emitted so the
    # output schema is stable; a sample without a key contributes None.
    extra_fields = {}
    default_extra_keys = {
        "turn_scores",
        "tool_rewards",
        "min_global_steps",
        "max_global_steps",
        "extras",
    }
    all_keys = {key for sample in inputs for key in sample.extra_fields} | default_extra_keys
    for key in all_keys:
        # Pre-allocate an object array so numpy never tries to broadcast nested values.
        temp_arr = np.empty(len(inputs), dtype=object)
        temp_arr[:] = [sample.extra_fields.get(key) for sample in inputs]
        extra_fields[key] = temp_arr

    non_tensor_batch.update(extra_fields)

    # reward_extra_keys is only reported when rm_scores was produced here.
    if "rm_scores" in batch.keys():
        meta_info = {"metrics": metrics, "reward_extra_keys": reward_extra_keys}
    else:
        meta_info = {"metrics": metrics}

    return DataProto(
        batch=batch,
        non_tensor_batch=non_tensor_batch,
        meta_info=meta_info,
    )
|
|
|
|
def _get_rollout_group_keys(non_tensor_batch: dict[str, np.ndarray]) -> list[Any]:
    """Return one grouping key per sample.

    ``"uid"`` is preferred, then ``"index"``. Without either, positions ``0..n-1`` are
    returned, where ``n`` is the length of the first entry. An empty batch gives ``[]``.
    """
    for preferred in ("uid", "index"):
        if preferred in non_tensor_batch:
            return non_tensor_batch[preferred].tolist()

    if non_tensor_batch:
        first_column = next(iter(non_tensor_batch.values()))
        return list(range(len(first_column)))
    return []
|
|
|
|
def compute_rollout_n(group_keys: list[Any]) -> list[int]:
    """Number each element by its position inside its run of equal consecutive keys.

    The counter restarts at 0 whenever the key differs from the previous one, so
    ``["a", "a", "b"]`` gives ``[0, 1, 0]``.
    """
    slots: list[int] = []
    for position, key in enumerate(group_keys):
        continues_run = position > 0 and group_keys[position - 1] == key
        slots.append(slots[-1] + 1 if continues_run else 0)
    return slots
|
|
|
|
def group_consecutive_indices(group_keys: list[Any]) -> list[list[int]]:
    """Split positions ``0..n-1`` into runs of equal consecutive keys.

    ``["a", "a", "b"]`` gives ``[[0, 1], [2]]``; an empty input gives ``[]``.
    """
    runs: list[list[int]] = []
    for position, key in enumerate(group_keys):
        if position > 0 and key == group_keys[position - 1]:
            # Same key as the previous element: extend the open run.
            runs[-1].append(position)
        else:
            runs.append([position])
    return runs
|
|
|
|
def get_trajectory_info(step, sample_keys, validate, rollout_ns=None):
    """Build one trajectory descriptor per sample.

    Args:
        step (int): global steps in the trainer.
        sample_keys (list): stable group key for each prompt, usually uid or dataset index.
        validate (bool): whether is a validate step.
        rollout_ns (list, optional): precomputed rollout slot inside each repeated group;
            derived from ``sample_keys`` with ``compute_rollout_n`` when omitted.

    Returns:
        list: one dict per sample with keys ``step``, ``sample_index``, ``rollout_n``
        and ``validate``.
    """
    slots = compute_rollout_n(sample_keys) if rollout_ns is None else rollout_ns
    return [
        {
            "step": step,
            "sample_index": key,
            "rollout_n": int(slots[position]),
            "validate": validate,
        }
        for position, key in enumerate(sample_keys)
    ]
|
|
|
|
class AgentLoopManager:
    """Agent loop manager that manages a group of agent loop workers.

    - if worker_group is not None, rollout server is in hybrid mode, share GPUs with training engine.
    - otherwise, rollout server is in standalone mode, use separate GPUs, e.g., one-step-off/fully async training.

    Args:
        config (DictConfig): whole config for main entrypoint.
        worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
        rollout_resource_pool (RayResourcePool): Resource pool for hybrid mode, only used by TensorRT-LLM.
        reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
    """

    def __init__(
        self,
        config: DictConfig,
        worker_group: RayWorkerGroup | None = None,
        rollout_resource_pool: RayResourcePool | None = None,
        reward_loop_worker_handles: list[ray.actor.ActorHandle] | None = None,
    ):
        self.config = config
        self.rollout_config, self.model_config = _get_rollout_and_model_config(config)
        self.worker_group = worker_group
        self.rollout_resource_pool = rollout_resource_pool
        self.reward_loop_worker_handles = reward_loop_worker_handles

        assert worker_group is not None or self.rollout_config.nnodes > 0, "nnodes must be > 0 in standalone mode"

        # Only fill in the defaults when these attributes have not already been provided
        # (e.g. by a subclass or as class attributes).
        if not hasattr(self, "rollout_replica_class"):
            self.rollout_replica_class = get_rollout_replica_class(self.rollout_config.name)
        if not hasattr(self, "agent_loop_workers_class"):
            self.agent_loop_workers_class = ray.remote(AgentLoopWorker)

    @classmethod
    @auto_await
    async def create(
        cls,
        config: DictConfig,
        worker_group: RayWorkerGroup | None = None,
        rollout_resource_pool: RayResourcePool | None = None,
        reward_loop_worker_handles: list[ray.actor.ActorHandle] | None = None,
    ):
        """Create agent loop manager.

        Construction is followed by three initialization steps, in order: rollout
        servers, the global load balancer, then the agent loop workers (which need
        both the server list and the load balancer handle).
        """
        instance = cls(config, worker_group, rollout_resource_pool, reward_loop_worker_handles)
        await instance._initialize_llm_servers()
        await instance._init_global_load_balancer()
        await instance._init_agent_loop_workers()
        return instance

    async def _initialize_llm_servers(self):
        """Create and start one rollout replica per model-parallel group.

        Raises:
            ValueError: if the available world size cannot host a single replica, or if
                Prometheus is enabled while ``disable_log_stats`` is set.
        """
        rollout_world_size = (
            self.rollout_config.tensor_model_parallel_size
            * self.rollout_config.data_parallel_size
            * self.rollout_config.pipeline_model_parallel_size
        )
        world_size = (
            self.worker_group.world_size
            if self.worker_group
            else self.rollout_config.n_gpus_per_node * self.rollout_config.nnodes
        )
        num_replicas = world_size // rollout_world_size
        if num_replicas < 1:
            # Fail here with a clear message instead of continuing with an empty server list.
            raise ValueError(
                "Not enough GPUs for a single rollout replica: "
                f"world_size={world_size}, rollout_world_size={rollout_world_size}"
            )

        self.rollout_replicas = [
            self.rollout_replica_class(
                replica_rank=replica_rank,
                config=self.rollout_config,
                model_config=self.model_config,
                gpus_per_node=self.rollout_config.n_gpus_per_node,
            )
            for replica_rank in range(num_replicas)
        ]

        if not self.worker_group:
            # Standalone mode: replicas manage their own resources.
            await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])
        elif self.rollout_config.name == "trtllm":
            # Hybrid mode for trtllm additionally needs the rollout resource pool.
            await asyncio.gather(
                *[
                    server.init_hybrid_colocated(self.worker_group, self.rollout_resource_pool)
                    for server in self.rollout_replicas
                ]
            )
        else:
            await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])

        self.server_handles = [server._server_handle for server in self.rollout_replicas]
        self.server_addresses = [server._server_address for server in self.rollout_replicas]

        print(f"AgentLoopManager: {self.server_addresses}")

        # Update the Prometheus configuration with the server addresses.
        if self.rollout_config.prometheus.enable:
            if self.rollout_config.disable_log_stats:
                raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.")
            update_prometheus_config(self.rollout_config.prometheus, self.server_addresses, self.rollout_config.name)

    async def _init_agent_loop_workers(self):
        """Start the configured number of agent loop worker actors.

        Workers are assigned round-robin to alive nodes that report CPU resources,
        using a soft node affinity.

        Raises:
            RuntimeError: if workers are requested but no alive node with CPU resources exists.
        """
        self.agent_loop_workers = []
        num_workers = self.rollout_config.agent.num_workers
        load_balancer_handle = self.global_load_balancer
        servers = list(zip(self.server_addresses, self.server_handles, strict=True))

        node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0]
        if num_workers > 0 and not node_ids:
            # Otherwise the round-robin below fails with an opaque ZeroDivisionError.
            raise RuntimeError("No alive Ray node with CPU resources is available for agent loop workers.")

        for i in range(num_workers):
            # Round-robin scheduling over all eligible nodes.
            node_id = node_ids[i % len(node_ids)]
            self.agent_loop_workers.append(
                self.agent_loop_workers_class.options(
                    # Random suffix keeps actor names unique across manager instances.
                    name=f"agent_loop_worker_{i}_{uuid4().hex[:8]}",
                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                        node_id=node_id, soft=True
                    ),
                ).remote(
                    self.config,
                    servers,
                    load_balancer_handle,
                    self.reward_loop_worker_handles,
                )
            )

    async def _init_global_load_balancer(self) -> None:
        """Create the load balancer actor shared by all agent loop workers."""
        self.global_load_balancer = GlobalRequestLoadBalancer.remote(
            server_actor_ids=self.server_addresses,
            max_cache_size=DEFAULT_ROUTING_CACHE_SIZE,
        )

    @auto_await
    async def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Split input batch and dispatch to agent loop workers.

        Args:
            prompts (DataProto): Input batch. A ``rollout_n`` entry is added to its
                ``non_tensor_batch`` when missing.

        Returns:
            DataProto: Output batch, with ``meta_info["timing"]`` holding the aggregated
            performance metrics.
        """
        if "rollout_n" not in prompts.non_tensor_batch:
            prompts.non_tensor_batch["rollout_n"] = np.array(
                compute_rollout_n(_get_rollout_group_keys(prompts.non_tensor_batch)),
                dtype=np.int32,
            )

        chunks = prompts.chunk(len(self.agent_loop_workers))
        outputs = await asyncio.gather(
            *[
                worker.generate_sequences.remote(chunk)
                for worker, chunk in zip(self.agent_loop_workers, chunks, strict=True)
            ]
        )
        output = DataProto.concat(outputs)
        output.non_tensor_batch.pop("rollout_n", None)

        # Per-worker metrics are aggregated into timing; remove them before the first
        # worker's remaining meta_info is reused for the merged output.
        metrics = [worker_output.meta_info.pop("metrics") for worker_output in outputs]
        timing = self._performance_metrics(metrics, output)

        output.meta_info = {"timing": timing, **outputs[0].meta_info}
        return output

    def _performance_metrics(self, metrics: list[list[dict[str, Any]]], output: DataProto) -> dict[str, float]:
        """Aggregate per-sample metrics into min/max/mean statistics.

        Args:
            metrics: One list of per-sample metric dicts per worker chunk.
            output: Merged output batch; used to report the prompt and response length of
                the slowest sample, so its row order must match the flattened metrics.

        Returns:
            Mapping from ``agent_loop/...`` metric names to values.
        """
        samples = [metric for chunk in metrics for metric in chunk]
        t_generate_sequences = np.array([metric["generate_sequences"] for metric in samples])
        t_tool_calls = np.array([metric["tool_calls"] for metric in samples])
        num_preempted = np.array([metric["num_preempted"] for metric in samples])
        # The arithmetic sampling metrics are optional and default to zero.
        arithmetic_sampling_used = np.array(
            [metric.get("arithmetic_sampling_used", 0) for metric in samples],
            dtype=np.float32,
        )
        arithmetic_sampling_avg_samples_per_group = np.array(
            [metric.get("arithmetic_sampling_avg_samples_per_group", 0.0) for metric in samples],
            dtype=np.float32,
        )
        arithmetic_sampling_groups_triggered_fraction = np.array(
            [metric.get("arithmetic_sampling_groups_triggered_fraction", 0.0) for metric in samples],
            dtype=np.float32,
        )

        timing = {}
        for name, values in (
            ("num_preempted", num_preempted),
            ("generate_sequences", t_generate_sequences),
            ("tool_calls", t_tool_calls),
        ):
            timing[f"agent_loop/{name}/min"] = values.min()
            timing[f"agent_loop/{name}/max"] = values.max()
            timing[f"agent_loop/{name}/mean"] = values.mean()
        timing["agent_loop/arithmetic_sampling/used_sample_mean"] = arithmetic_sampling_used.mean()
        timing["agent_loop/arithmetic_sampling/used_sample_total"] = arithmetic_sampling_used.sum()
        timing["agent_loop/arithmetic_sampling/avg_samples_per_group"] = (
            arithmetic_sampling_avg_samples_per_group.mean()
        )
        timing["agent_loop/arithmetic_sampling/groups_triggered_fraction"] = (
            arithmetic_sampling_groups_triggered_fraction.mean()
        )

        # Report details of the sample with the largest generation + tool-call time.
        slowest = np.argmax(t_generate_sequences + t_tool_calls)
        attention_mask = output.batch["attention_mask"][slowest]
        prompt_length = output.batch["prompts"].shape[1]
        timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
        timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
        timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
        timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()
        timing["agent_loop/slowest/num_preempted"] = num_preempted[slowest]

        return timing

    @auto_await
    async def clear_kv_cache(self):
        """Clear all rollout kv cache, but don't sleep."""
        await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])

    @auto_await
    async def start_profile(self, **kwargs):
        """Start profiling on all rollout replicas."""
        await asyncio.gather(*[replica.start_profile(**kwargs) for replica in self.rollout_replicas])

    @auto_await
    async def stop_profile(self):
        """Stop profiling on all rollout replicas."""
        await asyncio.gather(*[replica.stop_profile() for replica in self.rollout_replicas])
|
|