import asyncio
import atexit
import queue
import threading
import time
from typing import List, Optional

# Import core functions from sglang_rollout directly to avoid code duplication
from slime.rollout.sglang_rollout import GenerateState, generate_and_rm_group
from slime.utils.async_utils import run
from slime.utils.types import Sample

# Global worker manager: one shared worker, guarded by a lock so that
# concurrent callers never create or stop it twice.
_global_worker = None
_worker_lock = threading.Lock()


def get_global_worker(args, data_buffer):
    """Return the shared rollout worker, creating and starting it if needed.

    A new worker is created when none exists yet or when the existing
    worker's thread is no longer alive.

    Args:
        args: Run configuration; ``sglang_server_concurrency`` is read here
            and the whole object is handed to the worker.
        data_buffer: Sample source handed to the worker.

    Returns:
        The running ``AsyncRolloutWorker`` instance.
    """
    global _global_worker
    with _worker_lock:
        if _global_worker is None or not _global_worker.worker_thread.is_alive():
            print("Creating new global async worker...")
            _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.sglang_server_concurrency)
            _global_worker.start()
        return _global_worker


def stop_global_worker():
    """Stop the shared rollout worker (if any) and forget it."""
    global _global_worker
    with _worker_lock:
        if _global_worker is not None:
            _global_worker.stop()
            _global_worker = None


class AsyncRolloutWorker:
    """
    Simplified asynchronous rollout worker, using threads instead of processes.

    The worker runs its own asyncio event loop in a background thread, keeps
    pulling sample groups from ``data_buffer``, and publishes finished groups
    as ``(group_id, result)`` tuples on ``output_queue``. It keeps running
    independently of the lifecycle of any single rollout call.
    """

    def __init__(self, args, data_buffer, concurrency=10):
        """Store configuration; the thread is not started until ``start()``.

        Args:
            args: Run configuration (``rollout_batch_size`` is read by the loop).
            data_buffer: Object exposing ``get_samples(n)``.
            concurrency: Stored on the instance; not read elsewhere in this file.
        """
        self.args = args
        self.data_buffer = data_buffer  # Directly save data_buffer reference
        self.concurrency = concurrency
        self.running = True
        self.output_queue = queue.Queue(maxsize=1000)  # Continuous output queue
        self.worker_thread = None
        self.state = GenerateState(args)

    def _make_done_callback(self, group_id):
        """Build the done-callback that publishes one group's result."""

        def _on_done(done_task):
            # Only successful tasks produce a result. Failures are reported
            # once, by the cleanup pass in the worker loop; re-raising here
            # would only produce a second error via the loop's handler.
            if done_task.cancelled() or done_task.exception() is not None:
                return
            # NOTE: this is a blocking put executed on the event-loop thread;
            # when the queue is full it stalls the loop until a consumer
            # drains the queue.
            self.output_queue.put((group_id, done_task.result()))

        return _on_done

    async def continuous_worker_loop(self):
        """Continuous work loop - constantly get data from data_buffer and process.

        Keeps up to ``args.rollout_batch_size`` generation tasks in flight
        until ``running`` is cleared, then waits for the remaining tasks.
        """
        print("Continuous async rollout worker started")

        active_tasks = set()
        max_concurrent_tasks = self.args.rollout_batch_size
        group_id_counter = 0

        while self.running:
            try:
                # Clean up completed tasks (results were already published by
                # their done-callbacks; only failures are reported here).
                finished = {task for task in active_tasks if task.done()}
                for task in finished:
                    if task.cancelled():
                        continue
                    exc = task.exception()
                    if exc is not None:
                        print(f"Task failed with exception: {exc}")
                active_tasks -= finished

                # If active task count hasn't reached limit, try to get new data and start tasks
                while self.running and len(active_tasks) < max_concurrent_tasks:
                    samples = self.data_buffer.get_samples(1)
                    group = next(iter(samples), None)
                    if group is None:
                        # No data right now: leave the fill loop so the event
                        # loop gets control back instead of spinning here.
                        break

                    group_id = group_id_counter
                    group_id_counter += 1

                    # Create new async task
                    task = asyncio.create_task(
                        generate_and_rm_group(
                            self.args,
                            group,
                            sampling_params=self.state.sampling_params.copy(),
                            evaluation=False,
                        )
                    )
                    task.add_done_callback(self._make_done_callback(group_id))
                    active_tasks.add(task)

                # Brief sleep to avoid busy waiting
                await asyncio.sleep(1)

            except Exception as e:
                print(f"Error in continuous worker loop: {e}")
                await asyncio.sleep(1)

        if active_tasks:
            print(f"Waiting for {len(active_tasks)} continuous tasks to complete...")
            await asyncio.wait(active_tasks)

        print("Continuous async rollout worker stopped")

    def worker_thread_func(self):
        """Worker function running in independent thread"""
        asyncio.run(self.continuous_worker_loop())

    def start(self):
        """Start continuous work mode (no-op if the thread is already alive)."""
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.worker_thread = threading.Thread(target=self.worker_thread_func, daemon=True)
            self.worker_thread.start()
            print("Started continuous async worker thread")

    def stop(self):
        """Ask the loop to stop and wait up to 5 seconds for the thread."""
        self.running = False
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=5)
        print("Stopped async worker thread")

    def get_completed_groups(self, max_groups: Optional[int] = None) -> List[tuple]:
        """Pop finished groups from the output queue without blocking.

        Args:
            max_groups: Upper bound on the number of groups to pop. ``None``
                (the default) drains everything currently queued.

        Returns:
            A list of ``(group_id, result)`` tuples, possibly empty.
        """
        completed = []
        while max_groups is None or len(completed) < max_groups:
            try:
                completed.append(self.output_queue.get_nowait())
            except queue.Empty:
                break
        return completed

    def get_queue_size(self) -> int:
        """Get current output queue size"""
        return self.output_queue.qsize()


async def generate_rollout_async(args, rollout_id: int, data_buffer) -> List[List[Sample]]:
    """
    Simplified asynchronous rollout generation - using global continuous worker.

    Collects ``args.rollout_batch_size`` finished groups from the global
    worker's output queue, polling until enough are available.

    Args:
        args: Run configuration; ``rollout_global_dataset`` must be truthy.
        rollout_id: Accepted for interface compatibility; not used here.
        data_buffer: Sample source passed to the global worker.

    Returns:
        The collected groups, sorted by the ``index`` of each group's first
        sample.
    """
    assert args.rollout_global_dataset

    # Get global worker, which will run continuously
    worker = get_global_worker(args, data_buffer)

    # Simplified: directly use rollout_batch_size as target
    target_data_size = args.rollout_batch_size

    data = []
    do_print = True

    print(f"Starting async rollout generation for {target_data_size} groups")
    print(f"Global worker queue size: {worker.get_queue_size()}")

    # Main loop: collect results from global worker's output queue
    start_time = time.time()
    last_progress_time = start_time
    no_progress_timeout = 30.0  # Warn if no progress for 30 seconds

    while len(data) < target_data_size:
        # Pull only as many groups as are still missing, so surplus results
        # stay in the worker's queue for the next rollout instead of being
        # dropped when this function returns.
        completed = worker.get_completed_groups(max_groups=target_data_size - len(data))
        if completed:
            last_progress_time = time.time()

        for _group_id, group in completed:
            if do_print:
                print(
                    f"First rollout sample: {[group[0].prompt + group[0].response]}, "
                    f"label: {group[0].label}, reward: {group[0].reward}",
                    flush=True,
                )
                do_print = False

            # Simplified: directly add samples, no filters used
            data.append(group)

        # Check progress
        current_time = time.time()
        if current_time - last_progress_time > no_progress_timeout:
            print(
                f"Warning: No progress for {no_progress_timeout}s. "
                f"Queue size: {worker.get_queue_size()}, "
                f"Collected: {len(data)}/{target_data_size}"
            )
            last_progress_time = current_time

        # If no results were processed, brief sleep to avoid busy waiting
        if not completed:
            await asyncio.sleep(0.01)

    duration = time.time() - start_time
    print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}")

    if data:
        print(
            f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, "
            f"label: {data[-1][0].label}, reward: {data[-1][0].reward}",
            flush=True,
        )

    data.sort(key=lambda group: group[0].index)
    return data


def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False):
    """Synchronous entry point for the fully asynchronous rollout.

    Args:
        args: Run configuration forwarded to ``generate_rollout_async``.
        rollout_id: Forwarded to ``generate_rollout_async``.
        data_buffer: Sample source forwarded to ``generate_rollout_async``.
        evaluation: Must be falsy; evaluation is not supported.

    Returns:
        Whatever ``run`` returns for the rollout coroutine (presumably the
        coroutine's result - confirm against ``slime.utils.async_utils.run``).

    Raises:
        ValueError: If ``evaluation`` is truthy.
    """
    if evaluation:
        raise ValueError("Evaluation mode not supported in simple async rollout")

    completed_samples = run(generate_rollout_async(args, rollout_id, data_buffer))
    return completed_samples


# Register exit cleanup function
atexit.register(stop_global_worker)