File size: 8,610 Bytes
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from collections import deque
from typing import Any
import ray
from omegaconf import DictConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@ray.remote(num_cpus=2, max_concurrency=20)
class MessageQueue:
    """
    Simplified Ray-based asynchronous message queue for communication between Rollouter and Trainer.

    Samples are kept in a bounded FIFO (``queue``); when it is full, the oldest
    sample is evicted to make room for the new one. Validation payloads are kept
    in a separate, unbounded FIFO (``val_queue``). Every method takes the same
    ``asyncio.Lock``, which also backs the condition that consumers wait on.
    """

    def __init__(self, config: DictConfig, max_queue_size: int = 1000):
        """
        Create an empty queue.

        Args:
            config: Configuration object; stored as ``self.config``.
            max_queue_size: Maximum number of samples held at once. Must convert
                to a positive integer.

        Raises:
            ValueError: If ``max_queue_size`` is None or not positive.
        """
        self.config = config
        if max_queue_size is None:
            raise ValueError(f"max_queue_size cannot be None, got: {max_queue_size}")
        self.max_queue_size = int(max_queue_size)
        # A zero-capacity queue would make put_sample() pop from an empty deque,
        # so reject it up front instead of failing on the first put.
        if self.max_queue_size <= 0:
            raise ValueError(f"max_queue_size must be a positive integer, got: {max_queue_size}")

        self.queue = deque(maxlen=self.max_queue_size)
        self.val_queue = deque()

        # Cleared by shutdown() so that blocked consumers can exit.
        self.running = True

        # The condition shares the lock, so holding the lock is enough to
        # wait on / notify the condition.
        self._lock = asyncio.Lock()
        self._consumer_condition = asyncio.Condition(self._lock)

        # Counters reported by get_statistics().
        self.total_produced = 0
        self.total_consumed = 0
        self.dropped_samples = 0

        print(f"[MessageQueue] initialized with max_queue_size={max_queue_size}")

    async def put_sample(self, sample: Any) -> bool:
        """
        Put a batch sample into the queue.

        If the queue is already full, the oldest sample is discarded first.

        Args:
            sample: Sample data.

        Returns:
            bool: True if the sample was stored without evicting anything,
            False if an older sample had to be dropped to make room (the new
            sample is stored in both cases).
        """
        async with self._lock:
            # If queue is full, remove the oldest sample (rarely happens)
            dropped = len(self.queue) >= self.max_queue_size
            if dropped:
                self.queue.popleft()
                self.dropped_samples += 1
                logger.warning("Queue full, dropped sample")

            self.queue.append(sample)
            self.total_produced += 1

            # Notify waiting consumers
            self._consumer_condition.notify_all()

            if self.total_produced % 100 == 0:
                print(f"MessageQueue stats: produced={self.total_produced}, queue_size={len(self.queue)}")

            return not dropped

    async def get_sample(self) -> tuple[Any, int] | None:
        """
        Get a single sample from the queue, waiting until one is available.

        Returns:
            A ``(sample, remaining)`` tuple, where ``remaining`` is the queue
            length after the sample was removed, or None if the queue has been
            shut down and is empty.
        """
        async with self._lock:
            # Re-check after every wake-up: another consumer may have taken the
            # sample, or shutdown() may have been called.
            while len(self.queue) == 0 and self.running:
                await self._consumer_condition.wait()

            # If queue is closed and empty, return None
            if not self.running and len(self.queue) == 0:
                return None

            data = self.queue.popleft()
            self.total_consumed += 1
            return data, len(self.queue)

    async def get_queue_size(self) -> int:
        """Get current queue length."""
        async with self._lock:
            return len(self.queue)

    async def get_statistics(self) -> dict[str, Any]:
        """Get queue statistics: current size, produced/consumed/dropped counters and capacity."""
        async with self._lock:
            return {
                "queue_size": len(self.queue),
                "total_produced": self.total_produced,
                "total_consumed": self.total_consumed,
                "dropped_samples": self.dropped_samples,
                "max_queue_size": self.max_queue_size,
            }

    async def clear_queue(self):
        """Clear the sample queue (the validation queue is left untouched)."""
        async with self._lock:
            cleared_count = len(self.queue)
            self.queue.clear()
            logger.info("Cleared %s samples from queue", cleared_count)

    async def shutdown(self):
        """Shutdown the message queue and wake every consumer blocked in get_sample()."""
        async with self._lock:
            self.running = False
            # Notify all waiting coroutines so they can exit
            self._consumer_condition.notify_all()
        logger.info("MessageQueue shutdown")

    async def get_memory_usage(self) -> dict:
        """
        Get a rough estimate of the memory held by queued samples.

        The estimate is taken from the sample at the head of the queue and
        multiplied by the number of queued samples. ``sys.getsizeof`` is shallow,
        so fixed per-component allowances are added on top of it.

        Returns:
            dict: ``queue_samples``, ``estimated_memory_bytes`` and
            ``estimated_memory_mb``.
        """
        import sys

        async with self._lock:
            total_size = 0
            sample_count = len(self.queue)

            if sample_count > 0:
                # Peek at the head in O(1) instead of copying the whole deque.
                sample = self.queue[0]
                try:
                    sample_size = sys.getsizeof(sample)
                    # Samples are presumably RolloutSample objects; probe for
                    # their components rather than assuming the type.
                    if hasattr(sample, "original_batch_dict") and sample.original_batch_dict:
                        batch_data = sample.original_batch_dict.get("batch", {})
                        sample_size += len(batch_data) * 1000  # Roughly estimate 1KB per batch entry
                    if hasattr(sample, "agent_loop_output"):
                        sample_size += 5000  # Roughly estimate 5KB for AgentLoopOutput
                    total_size = sample_size * sample_count
                except Exception:
                    # Best-effort estimate: fall back to a flat figure rather
                    # than failing a purely diagnostic call.
                    total_size = sample_count * 15000  # Roughly estimate 15KB per RolloutSample

            return {
                "queue_samples": sample_count,
                "estimated_memory_bytes": total_size,
                "estimated_memory_mb": total_size / (1024 * 1024),
            }

    async def put_validate(self, data):
        """Append a validation payload to the (unbounded) validation queue."""
        async with self._lock:
            self.val_queue.append(data)

    async def get_validate(self):
        """Pop the oldest validation payload, or return None immediately if there is none."""
        async with self._lock:
            if self.val_queue:
                return self.val_queue.popleft()
            return None
class MessageQueueClient:
"""Asyncio-compatible MessageQueue client for communicating with MessageQueue Actor"""
def __init__(self, queue_actor: Any):
self.queue_actor = queue_actor
async def put_sample(self, sample: Any) -> bool:
"""Put batch into queue (async)"""
future = self.queue_actor.put_sample.remote(sample)
return await asyncio.wrap_future(future.future())
async def put_validate(self, data: Any) -> bool:
future = self.queue_actor.put_validate.remote(data)
return await asyncio.wrap_future(future.future())
def get_validate_sync(self) -> Any | None:
return ray.get(self.queue_actor.get_validate.remote())
async def get_sample(self) -> Any | None:
"""Get single sample from queue, wait until one is available (async)"""
future = self.queue_actor.get_sample.remote()
return await asyncio.wrap_future(future.future())
async def get_queue_size(self) -> int:
"""Get queue size (async)"""
future = self.queue_actor.get_queue_size.remote()
return await asyncio.wrap_future(future.future())
async def get_statistics(self) -> dict[str, Any]:
"""Get statistics (async)"""
future = self.queue_actor.get_statistics.remote()
return await asyncio.wrap_future(future.future())
async def clear_queue(self):
"""Clear queue (async)"""
future = self.queue_actor.clear_queue.remote()
await asyncio.wrap_future(future.future())
async def shutdown(self):
"""Shutdown queue (async)"""
future = self.queue_actor.shutdown.remote()
await asyncio.wrap_future(future.future())
async def get_memory_usage(self) -> dict:
"""Get memory usage statistics (async)"""
future = self.queue_actor.get_memory_usage.remote()
return await asyncio.wrap_future(future.future())
def get_sample_sync(self) -> Any | None:
"""Get single sample from queue (sync - deprecated, use get_sample instead)"""
return ray.get(self.queue_actor.get_sample.remote())
def get_statistics_sync(self) -> dict[str, Any]:
"""Get statistics (sync - deprecated, use get_statistics instead)"""
return ray.get(self.queue_actor.get_statistics.remote())
|