shinka-backup / ccevolve /baselines /thetaevolve /slime_plugins /rollout_buffer /rollout_buffer_example.py
| import asyncio | |
| import time | |
| from typing import Any, Dict, List | |
| import aiohttp | |
| import requests | |
| import wandb | |
| from transformers import AutoTokenizer | |
| from slime.utils.async_utils import run | |
| from slime.utils.mask_utils import MultiTurnLossMaskGenerator | |
| from slime.utils.types import Sample | |
# Only generate_rollout is exported as this module's public interface.
__all__ = ["generate_rollout"]
# Module-level state shared across calls.
# Tokenizer slot; starts as None. NOTE(review): check which functions read or
# assign it before relying on it.
TOKENIZER = None
# True until generate_rollout_async has sent the one-time start_rollout request
# to the rollout buffer server; that function then sets it to False.
START_ROLLOUT = True
def select_rollout_data(args, results, need_length):
    """
    Select the newest ``need_length`` groups of rollout samples.

    Samples are grouped by ``instance_id``. Each group is stamped with the
    largest timestamp found among its samples, read from ``item["timestamp"]``
    or, failing that, ``item["extra_info"]["timestamp"]``. Groups with no
    timestamp count as 0. Groups are ordered newest first (ties keep the order
    in which groups first appeared) and the first ``need_length`` are kept.

    Args:
        args: Run configuration. Not read by the selection itself; kept for
            interface compatibility with callers.
        results: List of rollout data dicts. Every item must contain
            ``"instance_id"``.
        need_length: Number of groups (not samples) to keep.

    Returns:
        A list of groups, newest first. Each group is the list of sample dicts
        sharing one ``instance_id``. If fewer than ``need_length`` groups exist,
        all of them are returned (with a warning). A non-positive
        ``need_length`` yields ``[]``. An empty ``results`` is returned as-is.

    Raises:
        ValueError: If an item has no ``"instance_id"``.
    """
    if not results:
        return results

    # Group samples by instance_id, preserving first-seen order of groups.
    groups = {}
    for item in results:
        if "instance_id" not in item:
            raise ValueError("instance_id must be in item")
        groups.setdefault(item["instance_id"], []).append(item)
    print(f"📊 Total groups: {len(groups)}, total samples: {len(results)}")

    # need_length counts groups; a negative value must not turn the slice
    # below into "drop the last N groups".
    if need_length <= 0:
        return []
    if len(groups) < need_length:
        print(
            f"⚠️ Only {len(groups)} groups available, fewer than the {need_length} requested; returning all of them"
        )

    def group_timestamp(group_items):
        # Newest timestamp carried by any sample of the group; 0 when none has one.
        timestamps = []
        for item in group_items:
            extra_info = item.get("extra_info")
            if "timestamp" in item:
                timestamps.append(float(item["timestamp"]))
            elif isinstance(extra_info, dict) and "timestamp" in extra_info:
                timestamps.append(float(extra_info["timestamp"]))
        return max(timestamps, default=0)

    # Newest first. sorted() is stable even with reverse=True, so equal
    # timestamps keep first-seen group order.
    ranked = sorted(
        ((group_timestamp(items), items) for items in groups.values()),
        key=lambda pair: pair[0],
        reverse=True,
    )[:need_length]

    # The result stays grouped: callers iterate group -> samples.
    selected_results = [items for _, items in ranked]

    # Statistics for monitoring.
    if ranked:
        newest_ts = ranked[0][0]
        oldest_ts = ranked[-1][0]
        num_samples = sum(len(items) for items in selected_results)
        print(f"📈 Selected {len(selected_results)} groups with {num_samples} samples")
        print(f"📈 Group timestamp range: {oldest_ts:.2f} to {newest_ts:.2f}")
        print(f"📈 Time span: {newest_ts - oldest_ts:.2f} seconds")
    return selected_results
def log_raw_info(args, all_meta_info, rollout_id):
    """
    Aggregate the meta info reported by the rollout buffer and log it.

    The aggregate ``total_samples`` is the sum of every entry's
    ``"total_samples"``. ``avg_reward`` is the sample-weighted mean of the
    entries' ``"avg_reward"``, weighted only over entries that report both
    keys. It is omitted when that weight is zero.

    Metrics go to wandb when ``args.use_wandb`` is set and to TensorBoard when
    ``args.use_tensorboard`` is set; either flag may be missing from ``args``.
    Logging failures are printed and swallowed. Without either backend the
    metrics are only printed.

    Args:
        args: Run configuration (``use_wandb``, ``use_tensorboard``,
            ``wandb_always_use_train_step`` and, when the latter is set,
            ``rollout_batch_size``, ``n_samples_per_prompt``,
            ``global_batch_size``).
        all_meta_info: List of meta-info dicts returned by the buffer.
        rollout_id: Current rollout step.

    Returns:
        The aggregated metrics dict (``{}`` when ``all_meta_info`` is empty).
    """
    if not all_meta_info:
        return {}

    final_meta_info = {
        "total_samples": sum(meta["total_samples"] for meta in all_meta_info if "total_samples" in meta)
    }
    # Weight the reward only over entries that actually report one, so an
    # entry without avg_reward does not drag the mean toward zero.
    rewarded = [meta for meta in all_meta_info if "avg_reward" in meta and "total_samples" in meta]
    rewarded_samples = sum(meta["total_samples"] for meta in rewarded)
    if rewarded_samples > 0:
        final_meta_info["avg_reward"] = (
            sum(meta["avg_reward"] * meta["total_samples"] for meta in rewarded) / rewarded_samples
        )

    use_wandb = getattr(args, "use_wandb", False)
    use_tensorboard = getattr(args, "use_tensorboard", False)
    if not (use_wandb or use_tensorboard):
        print(f"no filter rollout log {rollout_id}: {final_meta_info}")
        return final_meta_info

    # Only keys that were actually computed are logged (avg_reward may be absent).
    log_dict = {f"rollout/no_filter/{key}": value for key, value in final_meta_info.items()}
    try:
        step = (
            rollout_id
            if not getattr(args, "wandb_always_use_train_step", False)
            else rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
        )
        log_dict["rollout/step"] = step
        if use_wandb:
            wandb.log(log_dict)
        if use_tensorboard:
            from slime.utils.tensorboard_utils import _TensorboardAdapter

            _TensorboardAdapter(args).log(data=log_dict, step=step)
        print(f"no filter rollout log {rollout_id}: {log_dict}")
    except Exception as e:
        # Best-effort: a logging backend failure must not abort the rollout.
        print(f"Failed to log to wandb/tensorboard: {e}")
        print(f"no filter rollout log {rollout_id}: {final_meta_info}")
    return final_meta_info
async def get_rollout_data(api_base_url: str) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Poll the rollout buffer until it reports data, then return that data.

    Sends ``POST {api_base_url}/get_rollout_data`` (120 s timeout per request)
    every 3 seconds until the response JSON has a truthy ``"success"``. It waits
    indefinitely and prints a notice roughly every 30 seconds of waiting.

    The ``"data"`` field is accepted in three shapes:
      * a list of sample dicts;
      * a list of wrappers ``{"data": sample, ...}``, which are unwrapped;
      * a dict ``{"data": [...], "meta_info": {...}}`` (``meta_info`` optional).

    Returns:
        ``(samples, meta_info)``. ``meta_info`` is ``{}`` when the buffer sent none.

    Raises:
        aiohttp.ClientResponseError: On a non-2xx HTTP status.
        KeyError: If the response JSON lacks ``"success"`` or ``"data"``.
        ValueError: If ``"data"`` has an unexpected shape, or a sample lacks one
            of ``uid``, ``instance_id``, ``messages``, ``reward``, ``extra_info``.
    """
    request_timeout = aiohttp.ClientTimeout(total=120)
    # monotonic() is immune to wall-clock adjustments while we wait.
    wait_started = time.monotonic()
    async with aiohttp.ClientSession() as session:
        while True:
            async with session.post(
                f"{api_base_url}/get_rollout_data", json={}, timeout=request_timeout
            ) as response:
                response.raise_for_status()
                resp_json = await response.json()
            if resp_json["success"]:
                break
            await asyncio.sleep(3)
            if time.monotonic() - wait_started > 30:
                # Keep waiting indefinitely; just report progress every ~30 s.
                print("rollout data is not ready, have been waiting for 30 seconds")
                wait_started = time.monotonic()

    data = resp_json["data"]
    meta_info = {}
    if isinstance(data, list):
        # Unwrap [{"data": sample, ...}, ...]; an empty list passes through unchanged.
        if data and "data" in data[0]:
            data = [item["data"] for item in data]
    elif isinstance(data, dict) and "data" in data:
        # {"data": [...], "meta_info": {...}}; a missing/None meta_info becomes {}.
        meta_info = data.get("meta_info") or {}
        data = data["data"]
    if not isinstance(data, list):
        raise ValueError(f"Unexpected rollout data payload type: {type(data).__name__}")
    print(f"Meta info: {meta_info}")

    required_keys = {"uid", "instance_id", "messages", "reward", "extra_info"}
    for item in data:
        if not isinstance(item, dict) or not required_keys.issubset(item.keys()):
            raise ValueError(f"Missing required keys in response item: {item}")
    return data, meta_info
def start_rollout(api_base_url: str, args, metadata, retry_interval: float = 3.0):
    """
    Ask the rollout buffer server to start generating trajectories.

    Builds the rollout configuration from ``args`` and POSTs it to
    ``{api_base_url}/start_rollout``. Transport, HTTP-status and JSON-decoding
    failures are retried (every ``retry_interval`` seconds) until the server
    accepts. Any other exception, such as a payload that cannot be
    serialized, propagates instead of looping forever.

    Args:
        api_base_url: Base URL of the rollout buffer server.
        args: Run configuration. Supplies the router address, buffer URL, task
            type, prompt file, sampling parameters and tokenizer path.
        metadata: Mapping whose values are lists of already-finished instance
            ids (e.g. ``{str(rollout_id): [instance_id, ...]}``). They are sent
            as ``skip_instance_ids``. ``None`` is treated as empty.
        retry_interval: Seconds to wait between failed attempts.

    Returns:
        The server's decoded JSON response.
    """
    url = f"{api_base_url}/start_rollout"
    print(f"metadata: {metadata}")
    finished_groups_instance_id_list = [
        instance_id for sublist in (metadata or {}).values() for instance_id in sublist
    ]
    payload = {
        "num_process": str(getattr(args, "rollout_num_process", 100)),
        "num_epoch": str(args.num_epoch or 3),
        "remote_engine_url": f"http://{args.sglang_router_ip}:{args.sglang_router_port}",
        "remote_buffer_url": args.rollout_buffer_url,
        "task_type": args.rollout_task_type,
        "input_file": args.prompt_data,
        "num_repeat_per_sample": str(args.n_samples_per_prompt),
        "max_tokens": str(args.rollout_max_response_len),
        "sampling_params": {
            "max_tokens": args.rollout_max_response_len,
            "temperature": args.rollout_temperature,
            "top_p": args.rollout_top_p,
        },
        "tokenizer_path": args.hf_checkpoint,
        "skip_instance_ids": finished_groups_instance_id_list,
    }
    print("start rollout with payload: ", payload)
    while True:
        try:
            resp = requests.post(url, json=payload, timeout=10)
            resp.raise_for_status()
            data = resp.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decoding on older requests versions.
            print(f"[start_rollout] Failed to send rollout config: {e}")
            # Back off instead of hammering an unavailable server in a tight loop.
            time.sleep(retry_interval)
            continue
        print(f"[start_rollout] Success: {data}")
        return data
def _get_tokenizer(args):
    """Return the module-wide tokenizer, loading it from ``args.hf_checkpoint`` on first use.

    NOTE(review): assumes ``args.hf_checkpoint`` stays the same for the whole run;
    the first loaded tokenizer is reused by every later call.
    """
    global TOKENIZER
    if TOKENIZER is None:
        TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
    return TOKENIZER


async def _fetch_rollout_results(args, base_url, data_number_to_fetch):
    """Poll the buffer until ``data_number_to_fetch`` samples arrive or retries run out.

    Returns ``(results, all_meta_info)``. ``results`` may be short if
    ``args.fetch_trajectory_retry_times`` failures happen first
    (``-1`` means retry forever).
    """
    retry_times = 0
    results = []
    all_meta_info = []
    if args.fetch_trajectory_retry_times == -1:
        print(
            "⚠️ [get_rollout_data] Fetch trajectory retry times set to -1, will retry indefinitely until sufficient data is collected"
        )
    while args.fetch_trajectory_retry_times == -1 or retry_times < args.fetch_trajectory_retry_times:
        try:
            while len(results) < data_number_to_fetch:
                # Non-blocking wait: time.sleep() here would stall the whole event loop.
                await asyncio.sleep(5)
                data, meta_info = await get_rollout_data(api_base_url=base_url)
                results.extend(data)
                if meta_info:
                    all_meta_info.append(meta_info)
                print(f"get rollout data with length: {len(results)}")
            break
        except Exception as err:
            print(f"[get_rollout_data] Failed to get rollout data: {err}, retry times: {retry_times}")
            retry_times += 1
    if len(results) < data_number_to_fetch:
        print(
            f"⚠️ [get_rollout_data] Retries exhausted with only {len(results)}/{data_number_to_fetch} samples collected"
        )
    return results, all_meta_info


def _records_to_samples(args, tokenizer, grouped_records):
    """Convert grouped buffer records into ``Sample`` lists, one list per group."""
    # Built once instead of per record. NOTE(review): assumes get_loss_mask /
    # get_response_lengths keep no per-record state on the generator; confirm
    # in slime.utils.mask_utils.
    mask_generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type=args.loss_mask_type)
    sample_results = []
    for group_record in grouped_records:
        group_results = []
        for record in group_record:
            token_ids, loss_mask = mask_generator.get_loss_mask(record["messages"])
            response_length = mask_generator.get_response_lengths([loss_mask])[0]
            # Keep only the response part of the mask. Use an explicit start index:
            # loss_mask[-0:] would keep the WHOLE mask when response_length is 0.
            loss_mask = loss_mask[len(loss_mask) - response_length:]
            extra_info = record["extra_info"]
            truncated = extra_info.get("finish_reason") == "length"
            group_results.append(
                Sample(
                    index=record["instance_id"],
                    prompt=record["uid"],
                    tokens=token_ids,
                    response_length=response_length,
                    reward=record["reward"],
                    status=Sample.Status.TRUNCATED if truncated else Sample.Status.COMPLETED,
                    loss_mask=loss_mask,
                    metadata={**extra_info},
                )
            )
        sample_results.append(group_results)
    return sample_results


async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: bool = False) -> Dict[str, Any]:
    """
    Produce one training batch from the external rollout buffer.

    On the first call (while ``START_ROLLOUT`` is True) it sends the
    start-rollout request. It then tops ``data_buffer`` up to
    ``args.rollout_batch_size * args.n_samples_per_prompt`` samples by polling
    the buffer server. It keeps the newest groups, records finished instance ids
    in the buffer metadata under ``str(rollout_id)``, converts records to
    ``Sample`` objects and returns ``data_buffer.get_samples(args.rollout_batch_size)``.

    Args:
        args: Run configuration.
        rollout_id: Current rollout step (used for logging and as metadata key).
        data_buffer: Sample buffer; this function calls its ``get_metadata``,
            ``update_metadata``, ``get_buffer_length``, ``add_samples`` and
            ``get_samples`` methods.
        evaluation: Must be False; evaluation rollout is not implemented.

    Raises:
        NotImplementedError: If ``evaluation`` is True.
    """
    global START_ROLLOUT
    if evaluation:
        raise NotImplementedError("Evaluation rollout is not implemented")

    if START_ROLLOUT:
        metadata = data_buffer.get_metadata()
        start_inform = start_rollout(args.rollout_buffer_url, args, metadata)
        print(f"start rollout response: {start_inform}")
        print(f"start rollout id: {rollout_id}")
        START_ROLLOUT = False

    data_number_to_fetch = args.rollout_batch_size * args.n_samples_per_prompt - data_buffer.get_buffer_length()
    if data_number_to_fetch <= 0:
        print(
            f"❕buffer length: {data_buffer.get_buffer_length()}, buffer has enough data, return {args.rollout_batch_size} prompts"
        )
        return data_buffer.get_samples(args.rollout_batch_size)
    assert (
        data_number_to_fetch % args.n_samples_per_prompt == 0
    ), "data_number_to_fetch must be a multiple of n_samples_per_prompt"
    print(f"INFO: buffer length: {data_buffer.get_buffer_length()}, data_number_to_fetch: {data_number_to_fetch}")

    tokenizer = _get_tokenizer(args)
    results, all_meta_info = await _fetch_rollout_results(args, args.rollout_buffer_url, data_number_to_fetch)

    log_raw_info(args, all_meta_info, rollout_id)

    # Keep only the newest groups if the buffer returned more than needed.
    results = select_rollout_data(args, results, data_number_to_fetch // args.n_samples_per_prompt)

    if all_meta_info and "finished_groups" in all_meta_info[0]:
        finished_groups_instance_id_list = [
            instance_id for meta in all_meta_info for instance_id in meta.get("finished_groups", [])
        ]
        data_buffer.update_metadata({str(rollout_id): finished_groups_instance_id_list})

    print("finally get rollout data with length: ", len(results))

    data_buffer.add_samples(_records_to_samples(args, tokenizer, results))
    return data_buffer.get_samples(args.rollout_batch_size)  # type: ignore
def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
    """
    Synchronous entry point wrapping :func:`generate_rollout_async`.

    Builds the coroutine for the given arguments, drives it to completion with
    ``slime.utils.async_utils.run`` and returns its result. The async
    implementation raises ``NotImplementedError`` when ``evaluation`` is True.
    """
    rollout_coro = generate_rollout_async(args, rollout_id, data_buffer, evaluation)
    return run(rollout_coro)