File size: 6,495 Bytes
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
from typing import Any, Optional
import ray
import torch
from omegaconf import DictConfig
from verl.experimental.agent_loop.agent_loop import (
AgentLoopManager,
AgentLoopWorker,
AsyncLLMServerManager,
TokenOutput,
)
from verl.protocol import DataProto
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import auto_await
from verl.utils.rollout_trace import (
rollout_trace_op,
)
# Module-level logger, named after this file's path. Its level is read from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class FullyAsyncLLMServerManager(AsyncLLMServerManager):
    """FullyAsyncLLMServerManager supports resume generation on partial rollout, making rollout interruption
    invisible to the AgentLoop.

    ``generate`` keeps re-issuing the request with the tokens produced so far appended to the prompt
    until the underlying call stops for a reason other than an abort, or the token budget is used up.
    """

    @rollout_trace_op
    async def generate(
        self,
        request_id,
        *,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        image_data: Optional[list[Any]] = None,
        video_data: Optional[list[Any]] = None,
    ) -> TokenOutput:
        """Generate tokens from prompt ids, resuming transparently after aborted calls.

        Args:
            request_id (str): request id for sticky session.
            prompt_ids (List[int]): List of prompt token ids.
            sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. If it contains
                ``max_tokens`` (checked first) or ``max_new_tokens``, that value is treated as the total
                budget across all resumed calls. The dict passed in is not modified.
            image_data (Optional[List[Any]]): Image data for the chat completion.
            video_data (Optional[List[Any]]): Video data for the chat completion.

        Returns:
            TokenOutput: token output merged over every underlying call. ``extra_fields`` carries
            ``global_steps`` (from the last call), ``min_global_steps`` (from the first call) and
            ``max_global_steps`` (from the last call). ``stop_reason`` is forced to ``"length"`` when
            the accumulated tokens reach the budget.
        """
        # Work on a private copy: the remaining budget is rewritten before every resume, and
        # those writes must not leak into the caller's dict (which may be reused for later calls).
        sampling_params = dict(sampling_params)

        limit_key = None
        if "max_tokens" in sampling_params:
            limit_key = "max_tokens"
        elif "max_new_tokens" in sampling_params:
            limit_key = "max_new_tokens"
        original_max_tokens = sampling_params.get(limit_key) if limit_key else None

        final_output = TokenOutput(
            token_ids=[],
            log_probs=[],
            num_preempted=0,
        )
        min_global_steps, max_global_steps = None, None

        while True:
            # 1. generate tokens, continuing from everything produced so far
            output = await super().generate(
                request_id=request_id,
                prompt_ids=prompt_ids + final_output.token_ids,
                sampling_params=sampling_params,
                image_data=image_data,
                video_data=video_data,
            )

            # 2. merge output into final_output
            final_output.token_ids.extend(output.token_ids)
            if output.log_probs is not None:
                final_output.log_probs.extend(output.log_probs)
            if output.routed_experts is not None:
                if final_output.routed_experts is None:
                    final_output.routed_experts = output.routed_experts
                else:
                    final_output.routed_experts = torch.cat([final_output.routed_experts, output.routed_experts], dim=0)
            if output.num_preempted is not None:
                final_output.num_preempted += output.num_preempted
            final_output.stop_reason = output.stop_reason

            # update model weights version: min is pinned by the first call, max follows the latest call
            global_steps = output.extra_fields.get("global_steps", None)
            if min_global_steps is None:
                min_global_steps = global_steps
            max_global_steps = global_steps

            # 3. update max_new_tokens (on the private copy only)
            if original_max_tokens is not None:
                sampling_params[limit_key] = original_max_tokens - len(final_output.token_ids)
                if len(final_output.token_ids) >= original_max_tokens:
                    # Budget exhausted: report a length stop even if the last call was aborted.
                    final_output.stop_reason = "length"
                    break

            # 4. check stop reason: only resume after an abort, and only when partial rollout is enabled
            if output.stop_reason not in ("aborted", "abort") or not self.config.async_training.partial_rollout:
                break

        final_output.extra_fields["global_steps"] = global_steps
        final_output.extra_fields["min_global_steps"] = min_global_steps
        final_output.extra_fields["max_global_steps"] = max_global_steps
        return final_output
@ray.remote
class FullyAsyncAgentLoopWorker(AgentLoopWorker):
    """Ray actor form of AgentLoopWorker whose server manager is a FullyAsyncLLMServerManager."""

    def __init__(
        self,
        config: DictConfig,
        servers: list[tuple[str, ray.actor.ActorHandle]],
        load_balancer_handle: ray.actor.ActorHandle,
        reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
    ):
        """Create the server manager, then run the base worker initialization.

        Args:
            config (DictConfig): Configuration, passed to both the server manager and the base class.
            servers (list[tuple[str, ActorHandle]]): Server entries, passed to both the server manager
                and the base class.
            load_balancer_handle (ActorHandle): Passed to both the server manager and the base class.
            reward_loop_worker_handles (list[ActorHandle]): Passed to the base class only.
        """
        # Set before the base initializer runs.
        # NOTE(review): presumably AgentLoopWorker.__init__ keeps an already-present
        # server_manager rather than building its own -- confirm against the base class.
        self.server_manager = FullyAsyncLLMServerManager(config, servers, load_balancer_handle)
        super().__init__(config, servers, load_balancer_handle, reward_loop_worker_handles)
class FullyAsyncAgentLoopManager(AgentLoopManager):
    """AgentLoopManager that uses FullyAsyncAgentLoopWorker and hands out requests one worker at a time."""

    def __init__(
        self,
        config: DictConfig,
        worker_group: RayWorkerGroup = None,
        rollout_resource_pool: RayResourcePool = None,
        reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
    ):
        """Register the worker class, then defer to the base manager's initialization.

        All four arguments are forwarded unchanged to ``AgentLoopManager.__init__``.
        """
        # Set ahead of the base initializer.
        # NOTE(review): presumably the base class reads this attribute when creating workers -- confirm.
        self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
        super().__init__(config, worker_group, rollout_resource_pool, reward_loop_worker_handles)

    @auto_await
    async def generate_sequences_single(self, prompts: DataProto) -> DataProto:
        """Send the input to one agent loop worker and await its result.

        Args:
            prompts (DataProto): Input batch. Single sample data
        Returns:
            DataProto: Output batch.
        """
        object_ref = self._select_best_worker().generate_sequences.remote(prompts)
        return await asyncio.wrap_future(object_ref.future())

    def _select_best_worker(self):
        """Return the next worker in round-robin order and advance the cursor."""
        try:
            cursor = self._worker_index
        except AttributeError:
            # First use: start the rotation at the first worker.
            cursor = self._worker_index = 0
        chosen = self.agent_loop_workers[cursor]
        self._worker_index = (cursor + 1) % len(self.agent_loop_workers)
        return chosen
|