| | """ |
| | Implement base data transfer protocol between any two functions, modules. |
| | We can subclass Protocol to define more detailed batch info with specific keys |
| | """ |

import contextlib
import copy
import logging
import math
import os
import pickle
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

import numpy as np
import ray
import tensordict
import torch
import torch.distributed
from packaging import version
from packaging.version import parse as parse_version
from tensordict import TensorDict
from torch.utils.data import DataLoader

from verl.utils.device import get_device_id, get_torch_device
from verl.utils.py_functional import union_two_dict
from verl.utils.torch_functional import allgather_dict_tensors

__all__ = ["DataProto", "union_tensor_dict"]

with contextlib.suppress(Exception):
    tensordict.set_lazy_legacy(False).set()
    if parse_version(tensordict.__version__) < parse_version("0.10.0"):
        tensordict.set_list_to_stack(True).set()


class _DataProtoConfigMeta(type):
    _config = {}

    auto_padding_key = "_verl_auto_padding"

    @property
    def auto_padding(cls):
        enabled_by_env = os.getenv("VERL_AUTO_PADDING", "FALSE").upper() in ["TRUE", "1"]
        return enabled_by_env or cls._config.get(cls.auto_padding_key, False)

    @auto_padding.setter
    def auto_padding(cls, enabled: bool):
        assert isinstance(enabled, bool), f"enabled must be a boolean, got {enabled} as {type(enabled)}"
        cls._config[cls.auto_padding_key] = enabled


class DataProtoConfig(metaclass=_DataProtoConfigMeta):
    pass


_padding_size_key = "_padding_size_key_x123d"


def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int):
    """Pad a DataProto so that its size is divisible by size_divisor.

    Args:
        data (DataProto): the unpadded DataProto
        size_divisor (int): size divisor

    Returns:
        data_padded (DataProto): the padded DataProto
        pad_size (int): the number of padded items
    """
    assert isinstance(data, DataProto), "data must be a DataProto"
    if len(data) % size_divisor != 0:
        pad_size = size_divisor - len(data) % size_divisor
        padding_protos = []
        remaining_pad = pad_size
        # the data may be shorter than pad_size, so pad in several rounds if needed
        while remaining_pad > 0:
            take_size = min(remaining_pad, len(data))
            padding_protos.append(data[:take_size])
            remaining_pad -= take_size
        data_padded = DataProto.concat([data] + padding_protos)
    else:
        if len(data) == 0:
            logging.warning("padding a DataProto with no item, no change made")
        pad_size = 0
        data_padded = data
    return data_padded, pad_size


def unpad_dataproto(data: "DataProto", pad_size):
    """Unpad the DataProto with pad_size, i.e. `data[:-pad_size]`."""
    if pad_size != 0:
        data = data[:-pad_size]
    return data
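

# A minimal round-trip sketch (hypothetical sizes and dispatch): pad a batch so
# it divides evenly across 8 workers, then strip the padding from the result:
#
#   padded, pad_size = pad_dataproto_to_divisor(batch, size_divisor=8)
#   output = worker_group.compute(padded)  # hypothetical dispatch call
#   output = unpad_dataproto(output, pad_size)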


def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
    """Union two tensordicts into tensor_dict1."""
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
        f"Two tensor dicts must have identical batch sizes. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
    )
    for key in tensor_dict2.keys():
        if key not in tensor_dict1.keys():
            tensor_dict1[key] = tensor_dict2[key]
        else:
            assert tensor_dict1[key].equal(tensor_dict2[key]), (
                f"{key} in tensor_dict1 and tensor_dict2 must hold equal values"
            )

    return tensor_dict1


def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> bool:
    """
    Recursively compares two NumPy arrays for strict equality, with special
    handling for object-dtype arrays, NaN values, and circular references.
    This function assumes that the two arguments provided are NumPy arrays.

    Args:
        array1: The first NumPy array.
        array2: The second NumPy array.
        visited: ids of objects already under comparison on the current path,
            used to break cycles.

    Returns:
        True if the arrays' dtypes, shapes, and all elements are equal.
    """
    if array1.dtype != array2.dtype or array1.shape != array2.shape:
        return False

    # fast path: non-object dtypes can be compared natively, treating NaN as equal
    if array1.dtype != "object":
        return np.array_equal(array1, array2, equal_nan=True)

    # object arrays need an element-wise deep comparison
    return all(_deep_equal(x, y, visited) for x, y in zip(array1.flat, array2.flat, strict=False))


def _deep_equal(a: Any, b: Any, visited: set[int]) -> bool:
    """
    Recursively performs a deep comparison between two Python objects.
    - Handles NaN values correctly (NaN == NaN evaluates to True).
    - Handles circular references.
    - Dispatches to _array_equal if both objects are NumPy arrays.
    - Otherwise, uses the standard '==' comparison.
    """
    if type(a) is not type(b):
        return False

    # if this object is already being compared higher up the call stack, we hit
    # a cycle; treat it as equal to avoid infinite recursion
    obj_id = id(a)
    if obj_id in visited:
        return True

    visited.add(obj_id)

    result = False
    if isinstance(a, float) and math.isnan(a) and math.isnan(b):
        result = True
    elif isinstance(a, np.ndarray):
        # type(a) is type(b), so b is also an ndarray here
        result = _array_equal(a, b, visited)
    else:
        result = a == b

    # remove the id so the same object can be revisited along a different path
    visited.remove(obj_id)
    return result


def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Union two dicts of numpy arrays into tensor_dict1."""
    for key, val in tensor_dict2.items():
        if key in tensor_dict1:
            assert isinstance(tensor_dict2[key], np.ndarray)
            assert isinstance(tensor_dict1[key], np.ndarray)
            # use deep equality to properly compare object arrays with nested contents
            assert _deep_equal(tensor_dict1[key], tensor_dict2[key], visited=set()), (
                f"`{key}` in tensor_dict1 and tensor_dict2 must hold equal values."
            )
        tensor_dict1[key] = val

    return tensor_dict1


def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
    if len(list_of_dict) == 0:
        return {}
    keys = list_of_dict[0].keys()
    output = {key: [] for key in keys}
    for data in list_of_dict:
        for key, item in data.items():
            assert key in output
            output[key].append(item)
    return output
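

# A worked example: list_of_dict_to_dict_of_list([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
# returns {"a": [1, 3], "b": [2, 4]}.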


def fold_batch_dim(data: "DataProto", new_batch_size):
    """
    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
    """
    batch_size = data.batch.batch_size[0]

    assert batch_size % new_batch_size == 0

    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch

    tensor = tensor.view(new_batch_size, -1)
    tensor.auto_batch_size_(batch_dims=1)

    for key, val in non_tensor.items():
        non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))

    return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)


def unfold_batch_dim(data: "DataProto", batch_dims=2):
    """
    Unfold the first n dims as new batch dim
    """
    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch
    tensor.auto_batch_size_(batch_dims=batch_dims)
    tensor = tensor.view(-1)

    batch_size = tensor.batch_size[0]

    non_tensor_new = {}

    for key, val in non_tensor.items():
        non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))

    return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)
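

# Shape sketch (hypothetical sizes): folding a DataProto whose tensors are
# [8, seq] with new_batch_size=2 yields tensors of shape [2, 4, seq]; calling
# unfold_batch_dim(..., batch_dims=2) on the result restores [8, seq]:
#
#   folded = fold_batch_dim(data, new_batch_size=2)
#   restored = unfold_batch_dim(folded, batch_dims=2)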


def serialize_single_tensor(obj: torch.Tensor) -> tuple[str, tuple[int, ...], np.ndarray]:
    # reinterpret the tensor as raw uint8 bytes; the dtype string and shape are
    # kept so the tensor can be reconstructed on the receiving side
    data = obj.flatten().contiguous().view(torch.uint8).numpy()
    dtype = str(obj.dtype).removeprefix("torch.")
    return dtype, obj.shape, data


def serialize_tensordict(batch: TensorDict) -> tuple[tuple[int, ...], Optional[str], dict[str, tuple]]:
    encoded_items: dict[str, tuple] = {}
    for k, v in batch.items():
        if not v.is_nested:
            encoded_items[k] = serialize_single_tensor(v)
        else:
            # nested tensors are encoded as (layout, [serialized components])
            layout = str(v.layout).removeprefix("torch.")
            data = [serialize_single_tensor(tensor) for tensor in v.unbind()]
            encoded_items[k] = (layout, data)

    batch_size = tuple(batch.batch_size)
    device = str(batch.device) if batch.device is not None else None
    return batch_size, device, encoded_items


def deserialize_single_tensor(arr: Any) -> torch.Tensor:
    dtype, shape, data = arr

    torch_dtype = getattr(torch, dtype)
    assert isinstance(torch_dtype, torch.dtype)

    # copy into a bytearray: torch.frombuffer shares memory with the buffer
    # and requires it to be writable
    buffer = bytearray(data)
    arr = torch.frombuffer(buffer, dtype=torch.uint8)
    # reinterpret the raw bytes as the original dtype and shape
    return arr.view(torch_dtype).view(shape)


def deserialize_tensordict(arr: Any) -> TensorDict:
    batch_size, device, encoded_items = arr
    decoded_items: dict[str, Any] = {}

    for k, v in encoded_items.items():
        if len(v) == 3:
            # (dtype, shape, data): a regular tensor
            decoded_items[k] = deserialize_single_tensor(v)
        elif len(v) == 2:
            # (layout, data): a nested tensor
            layout, data = v
            torch_layout = getattr(torch, layout)
            decoded_items[k] = torch.nested.as_nested_tensor(
                [deserialize_single_tensor(tensor) for tensor in data], layout=torch_layout
            )
        else:
            raise ValueError(f"Invalid tensor encoding format, expected length 2 or 3, got {len(v)}")

    return TensorDict(source=decoded_items, batch_size=batch_size, device=device)
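

# Round-trip sketch: serialize_tensordict and deserialize_tensordict are
# inverses for CPU tensordicts; this is what the numpy-based
# VERL_DATAPROTO_SERIALIZATION_METHOD path below relies on:
#
#   td = TensorDict(source={"obs": torch.randn(4, 8)}, batch_size=(4,))
#   restored = deserialize_tensordict(serialize_tensordict(td))
#   assert torch.equal(restored["obs"], td["obs"])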


def collate_fn(x: list["DataProtoItem"]):
    batch = []
    non_tensor_batch = []
    for data in x:
        batch.append(data.batch)
        non_tensor_batch.append(data.non_tensor_batch)
    batch = torch.stack(batch).contiguous()
    non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)
    for key, val in non_tensor_batch.items():
        non_tensor_batch[key] = np.array(val, dtype=object)
    return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


@dataclass
class DataProtoItem:
    batch: TensorDict = None
    non_tensor_batch: dict = field(default_factory=dict)
    meta_info: dict = field(default_factory=dict)


@dataclass
class DataProto:
    """
    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
    same batch size should be put inside batch.
    """

    batch: TensorDict = None
    non_tensor_batch: dict = field(default_factory=dict)
    meta_info: dict = field(default_factory=dict)

    def __post_init__(self):
        # perform the necessary consistency checking on construction
        self.check_consistency()

    def __len__(self):
        if self.batch is not None:
            return self.batch.batch_size[0]
        elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
            random_key = list(self.non_tensor_batch.keys())[0]
            return self.non_tensor_batch[random_key].shape[0]
        else:
            return 0

    def __getitem__(self, item):
        """
        Enhanced indexing for DataProto objects.

        Args:
            item: Can be one of:
                - int: A single index
                - slice: A slice object (start:stop:step)
                - list: A list of indices
                - numpy.ndarray: An array of indices
                - torch.Tensor: A tensor of indices

        Returns:
            DataProto: For all indexing types except single integers
            DataProtoItem: Only for single integer indices
        """
        # case 1: slice object - delegate to the slice method
        if isinstance(item, slice):
            return self.slice(item.start, item.stop, item.step)

        # case 2: list, numpy array, or torch tensor - delegate to select_idxs
        elif isinstance(item, list | np.ndarray | torch.Tensor):
            return self.select_idxs(item)

        # case 3: single integer - return a DataProtoItem for backward compatibility
        elif isinstance(item, int | np.integer):
            tensor_data = self.batch[item] if self.batch is not None else None
            non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
            return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

        # case 4: unsupported type
        else:
            raise TypeError(f"Indexing with {type(item)} is not supported")

    def __getstate__(self):
        if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
            # consolidating the TensorDict into one storage makes serialization faster
            if len(self.batch.keys()) > 0:
                batch = self.batch.contiguous().consolidate()
            else:
                batch = self.batch
        else:
            batch = self.batch

        if os.getenv("VERL_DATAPROTO_SERIALIZATION_METHOD") == "numpy":
            if batch is not None:
                batch = serialize_tensordict(self.batch)

            return (
                batch,
                self.non_tensor_batch,
                self.meta_info,
            )
        else:
            import io

            buffer = io.BytesIO()
            torch.save(batch, buffer)
            buffer_bytes = buffer.getvalue()
            return buffer_bytes, self.non_tensor_batch, self.meta_info

    def __setstate__(self, data):
        batch_deserialized_bytes, non_tensor_batch, meta_info = data

        if os.getenv("VERL_DATAPROTO_SERIALIZATION_METHOD") == "numpy":
            if batch_deserialized_bytes is not None:
                self.batch = deserialize_tensordict(batch_deserialized_bytes)
            else:
                self.batch = None
        else:
            import io

            batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
            batch = torch.load(
                batch_deserialized,
                weights_only=False,
                map_location="cpu" if not get_torch_device().is_available() else None,
            )
            self.batch = batch

        self.non_tensor_batch = non_tensor_batch
        self.meta_info = meta_info

    def save_to_disk(self, filepath):
        with open(filepath, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load_from_disk(filepath) -> "DataProto":
        with open(filepath, "rb") as f:
            data = pickle.load(f)
        return data

    def print_size(self, prefix=""):
        size_of_tensordict = 0
        if self.batch is not None:
            for _, tensor in self.batch.items():
                size_of_tensordict += tensor.element_size() * tensor.numel()
        size_of_numpy_array = 0
        for _, numpy_array in self.non_tensor_batch.items():
            size_of_numpy_array += numpy_array.nbytes

        size_of_numpy_array /= 1024**3
        size_of_tensordict /= 1024**3

        message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB"

        if prefix:
            message = f"{prefix}, " + message
        print(message)

    def check_consistency(self):
        """Check the consistency of the DataProto, mainly between batch and non_tensor_batch.
        We expose this as a public function so that users can call it directly.
        """
        if self.batch is not None:
            assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1"

        if self.non_tensor_batch is not None:
            for key, val in self.non_tensor_batch.items():
                assert isinstance(val, np.ndarray)

        if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0:
            # non_tensor_batch entries must line up with the batch dim of the TensorDict
            assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."

            batch_size = self.batch.batch_size[0]
            for key, val in self.non_tensor_batch.items():
                assert isinstance(val, np.ndarray), (
                    f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for "
                    f"{key=}, got {type(val)=}"
                )
                assert val.shape[0] == batch_size, (
                    f"key {key} length {len(val)} is not equal to batch size {batch_size}"
                )

    @classmethod
    def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False):
        """Create a DataProto from a dict of tensors and non_tensors"""
        tensors = {}
        non_tensors = {}

        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key] = val
            elif isinstance(val, np.ndarray):
                non_tensors[key] = val
            else:
                raise ValueError(f"Unsupported type in data {type(val)}")

        return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding)

    @classmethod
    def from_dict(
        cls,
        tensors: Optional[dict[str, torch.Tensor]] = None,
        non_tensors=None,
        meta_info=None,
        num_batch_dims=1,
        auto_padding=False,
    ):
        """Create a DataProto from a dict of tensors. This assumes that
        1. All the tensors in `tensors` have the same dim0
        2. Only dim0 is the batch dim
        """

        assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
        if non_tensors is not None:
            assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None."

        if tensors is None:
            tensors = {}
        if meta_info is None:
            meta_info = {}
        if non_tensors is None:
            non_tensors = {}

        assert isinstance(non_tensors, dict)

        # get and check batch size
        batch_size = None
        pivot_key = None
        for key, tensor in tensors.items():
            if batch_size is None:
                batch_size = tensor.shape[:num_batch_dims]
                pivot_key = key
            else:
                current_batch = tensor.shape[:num_batch_dims]
                assert batch_size == current_batch, (
                    f"Not all the tensors in tensors have the same batch size with batch_dims={num_batch_dims}. "
                    f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
                )

        for key, val in non_tensors.items():
            if not isinstance(val, np.ndarray):
                non_tensors[key] = np.array(val, dtype=object)

        tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None
        if auto_padding:
            meta_info[DataProtoConfig.auto_padding_key] = True
        return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
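
    # Construction sketch (hypothetical keys): tensors land in `batch`, numpy
    # object arrays land in `non_tensor_batch`:
    #
    #   data = DataProto.from_dict(
    #       tensors={"input_ids": torch.zeros(4, 16, dtype=torch.long)},
    #       non_tensors={"raw_prompt": np.array(["a", "b", "c", "d"], dtype=object)},
    #   )
    #   assert len(data) == 4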

    @classmethod
    def from_tensordict(
        cls,
        tensor_dict: TensorDict = None,
        meta_info=None,
        num_batch_dims=1,
    ):
        """Create a DataProto from a TensorDict. This assumes that
        1. All the tensors in tensor_dict have the same dim0
        2. Only dim0 is the batch dim
        """
        assert version.parse(tensordict.__version__) >= version.parse("0.10.0"), (
            "Build DataProto from TensorDict at least requires tensordict version 0.10.0"
        )
        from tensordict import NonTensorData, NonTensorStack

        assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
        if not all(isinstance(val, torch.Tensor) for val in tensor_dict.values()):
            assert num_batch_dims == 1, "only support num_batch_dims=1 when tensor_dict contains non tensor data."

        if meta_info is None:
            meta_info = {}
        batch = {}
        non_tensor_batch = {}
        batch_size = None
        for key, val in tensor_dict.items():
            if isinstance(val, torch.Tensor):
                batch[key] = val
                if batch_size is None:
                    batch_size = val.shape[:num_batch_dims]
            elif isinstance(val, NonTensorStack):
                non_tensor_batch[key] = np.array([elem.data for elem in val], dtype=object)
            elif isinstance(val, NonTensorData):
                meta_info[key] = val.data

        return cls(
            batch=TensorDict(batch, batch_size=batch_size),
            non_tensor_batch=non_tensor_batch,
            meta_info=meta_info,
        )

    def to(self, device) -> "DataProto":
        """Move the batch to device in place.

        Args:
            device (torch.device, str): torch device

        Returns:
            DataProto: the current DataProto
        """
        if self.batch is not None:
            self.batch = self.batch.to(device)
        return self

    def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto":
        """Select a subset of the DataProto via batch_keys and meta_info_keys

        Args:
            batch_keys (list, optional): a list of strings indicating the keys in batch to select
            non_tensor_batch_keys (list, optional): a list of keys indicating the non_tensor_batch to select
            meta_info_keys (list, optional): a list of keys indicating the meta info to select
            deepcopy (bool): whether to deepcopy the selected non_tensor_batch and meta_info

        Returns:
            DataProto: the DataProto with the selected batch_keys and meta_info_keys
        """
        if batch_keys is not None:
            batch_keys = tuple(batch_keys)
            sub_batch = self.batch.select(*batch_keys)
        else:
            sub_batch = self.batch

        if non_tensor_batch_keys is not None:
            non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
        else:
            non_tensor_batch = self.non_tensor_batch

        if deepcopy:
            non_tensor_batch = copy.deepcopy(non_tensor_batch)

        if meta_info_keys is not None:
            sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
        else:
            sub_meta_info = self.meta_info

        if deepcopy:
            sub_meta_info = copy.deepcopy(sub_meta_info)

        return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)

    def select_idxs(self, idxs):
        """
        Select specific indices from the DataProto.

        Args:
            idxs (torch.Tensor or numpy.ndarray or list): Indices to select

        Returns:
            DataProto: A new DataProto containing only the selected indices
        """
        if isinstance(idxs, list):
            idxs = torch.tensor(idxs)
            if idxs.dtype != torch.bool:
                idxs = idxs.type(torch.int32)

        if isinstance(idxs, np.ndarray):
            idxs_np = idxs
            idxs_torch = torch.from_numpy(idxs)
        else:
            idxs_torch = idxs
            idxs_np = idxs.detach().cpu().numpy()

        batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0]

        if self.batch is not None:
            # apply the torch indices per key and rebuild the TensorDict
            selected_batch = TensorDict(
                source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},
                batch_size=(batch_size,),
                device=self.batch.device,
            )
        else:
            selected_batch = None

        selected_non_tensor = {}
        for key, val in self.non_tensor_batch.items():
            selected_non_tensor[key] = val[idxs_np]

        return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)

    def slice(self, start=None, end=None, step=None):
        """
        Slice the DataProto and return a new DataProto object.
        This is an improved version of direct slicing which returns a DataProtoItem.

        Args:
            start (int, optional): Start index. Defaults to None (start from beginning).
            end (int, optional): End index (exclusive). Defaults to None (go to end).
            step (int, optional): Step size. Defaults to None (step=1).

        Returns:
            DataProto: A new DataProto containing the sliced data

        Examples:
            # Using the slice method directly
            sliced_data = data_proto.slice(10, 20)

            # Using enhanced indexing (returns DataProto)
            sliced_data = data_proto[10:20]
            sliced_data = data_proto[::2]  # Every other element

            # Using list indexing (returns DataProto)
            indices = [1, 5, 10]
            selected_data = data_proto[indices]

            # Single index still returns DataProtoItem
            single_item = data_proto[5]
        """
        # create a slice object
        slice_obj = slice(start, end, step)

        # handle the tensor batch
        if self.batch is not None:
            sliced_batch = self.batch[slice_obj]
        else:
            sliced_batch = None

        # handle the non-tensor batch
        sliced_non_tensor = {}
        for key, val in self.non_tensor_batch.items():
            sliced_non_tensor[key] = val[slice_obj]

        # return a new DataProto object
        return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)

    def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto":
        """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

        Args:
            batch_keys (list, optional): a list of strings indicating the keys in batch to pop
            non_tensor_batch_keys (list, optional): a list of keys indicating the non_tensor_batch to pop
            meta_info_keys (list, optional): a list of keys indicating the meta info to pop

        Returns:
            DataProto: the DataProto with the popped batch_keys and meta_info_keys
        """
        if batch_keys is None:
            batch_keys = []
        if meta_info_keys is None:
            meta_info_keys = []
        if non_tensor_batch_keys is None:
            non_tensor_batch_keys = []

        tensors = {}
        # pop tensors from the batch
        for key in batch_keys:
            assert key in self.batch.keys()
            tensors[key] = self.batch.pop(key)
        non_tensors = {}
        # pop non-tensor entries
        for key in non_tensor_batch_keys:
            assert key in self.non_tensor_batch.keys()
            non_tensors[key] = self.non_tensor_batch.pop(key)
        meta_info = {}
        for key in meta_info_keys:
            assert key in self.meta_info.keys()
            meta_info[key] = self.meta_info.pop(key)
        return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
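
    # Pop sketch (hypothetical keys): `pop` mutates the DataProto and returns the
    # removed subset as a new DataProto, e.g. to hand only generation inputs to a
    # rollout worker:
    #
    #   gen_batch = data.pop(batch_keys=["input_ids", "attention_mask"])
    #   # `data` no longer contains these keys; `gen_batch` does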

    def rename(self, old_keys=None, new_keys=None) -> "DataProto":
        """
        Note that this function only renames keys in the batch
        """

        def validate_input(keys):
            if keys is not None:
                if isinstance(keys, str):
                    keys = [keys]
                elif isinstance(keys, list):
                    pass
                else:
                    raise TypeError(f"keys must be a list or a string, but got {type(keys)}")
            return keys

        old_keys = validate_input(old_keys)
        new_keys = validate_input(new_keys)

        if len(new_keys) != len(old_keys):
            raise ValueError(
                f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}"
            )

        self.batch.rename_key_(tuple(old_keys), tuple(new_keys))

        return self

    def union(self, other: "DataProto") -> "DataProto":
        """Union with another DataProto. Union batch and meta_info separately.
        Throw an error if

        - there are conflicting keys in batch and they are not equal
        - the batch sizes of the two batches differ
        - there are conflicting keys in meta_info and they are not the same.

        Args:
            other (DataProto): another DataProto to union

        Returns:
            DataProto: the DataProto after union
        """
        self.batch = union_tensor_dict(self.batch, other.batch)
        self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
        self.meta_info = union_two_dict(self.meta_info, other.meta_info)
        return self

    def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
        r"""Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a normal
        PyTorch dataset. See https://pytorch.org/tensordict/stable/tutorials/data_fashion for more details.

        Args:
            mini_batch_size (int): mini-batch size when iterating the dataset. We require that
                ``batch.batch_size[0] % mini_batch_size == 0``.
            epochs (int): number of epochs when iterating the dataset.
            seed (int, optional): seed for the shuffling generator.
            dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The
                dataloader_kwargs is the kwargs passed to the DataLoader.

        Returns:
            Iterator: an iterator that yields a mini-batch of data at a time. The total number of iteration
            steps is ``self.batch.batch_size[0] * epochs // mini_batch_size``
        """
        assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
        # we can directly create a dataloader from TensorDict
        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = None

        assert isinstance(dataloader_kwargs, dict)
        train_dataloader = DataLoader(
            dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
        )

        def get_data():
            for _ in range(epochs):
                for d in train_dataloader:
                    d.meta_info = self.meta_info
                    yield d

        return iter(get_data())
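
    # Iterator sketch (hypothetical sizes): iterate a 128-sample DataProto in
    # shuffled mini-batches of 32 for one epoch:
    #
    #   it = data.make_iterator(mini_batch_size=32, epochs=1, seed=0,
    #                           dataloader_kwargs={"shuffle": True})
    #   for mini_batch in it:
    #       ...  # each mini_batch is a DataProto of length 32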

    def is_padding_enabled(self):
        """
        Check if padding is enabled for the DataProto.
        Returns:
            bool: True if padding is enabled, False otherwise.
        """
        dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False)
        return dataproto_specific_padding or DataProtoConfig.auto_padding

    def padding(self, padding_size, padding_candidate=""):
        """Pad the DataProto by concatenating with padding_candidate.repeat(padding_size)

        Args:
            padding_size (int): the number of repeated padding_candidate
            padding_candidate: the item to be repeated and appended to the DataProto, only supporting ["first", "last"]
        """
        if padding_size == 0:
            return
        padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1])
        padding_part = padding_candidate.repeat(padding_size)
        padded_dp = DataProto.concat([self, padding_part])
        self.batch = padded_dp.batch
        self.non_tensor_batch = padded_dp.non_tensor_batch

    def chunk(self, chunks: int) -> list["DataProto"]:
        """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.

        Args:
            chunks (int): the number of chunks to split on dim=0

        Returns:
            List[DataProto]: a list of DataProto after splitting
        """
        if not self.is_padding_enabled():
            assert len(self) % chunks == 0, (
                f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
            )

        bsz_in_batch = None
        if self.batch is not None:
            batch_lst = self.batch.chunk(chunks=chunks, dim=0)
            bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst])
            chunk_indices = np.cumsum(bsz_in_batch)[:-1]
        else:
            batch_lst = [None for _ in range(chunks)]

        non_tensor_batch_lst = [{} for _ in range(chunks)]
        for key, val in self.non_tensor_batch.items():
            assert isinstance(val, np.ndarray)
            if bsz_in_batch is not None:
                # split the non-tensors at the same boundaries as the tensor chunks
                non_tensor_lst = np.array_split(val, chunk_indices.tolist())
            else:
                non_tensor_lst = np.array_split(val, chunks)
            assert len(non_tensor_lst) == chunks
            for i in range(chunks):
                non_tensor_batch_lst[i][key] = non_tensor_lst[i]

        output = []
        for i in range(chunks):
            output.append(
                type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)
            )

        return output

    def split(self, split_size: int) -> list["DataProto"]:
        """Split the batch among dim=0 into pieces of size split_size. The meta_info is passed to each DataProto
        after split.

        Args:
            split_size (int): the size of each split

        Returns:
            List[DataProto]: a list of DataProto after splitting
        """
        return [self[i : i + split_size] for i in range(0, len(self), split_size)]
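
    # chunk vs. split sketch: for a DataProto of length 8, `chunk(4)` yields four
    # pieces of length 2 (fixed number of chunks), while `split(2)` yields pieces
    # of length 2 until the data is exhausted (fixed piece size).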

    @staticmethod
    def concat(data: list["DataProto"]) -> "DataProto":
        """Concat a list of DataProto. The batch is concatenated among dim=0.
        The meta_info is merged, with special handling for metrics from different workers.

        Args:
            data (List[DataProto]): list of DataProto

        Returns:
            DataProto: concatenated DataProto
        """
        batch_lst = []
        for batch in data:
            batch_lst.append(batch.batch)
        new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None

        non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
        for key, val in non_tensor_batch.items():
            non_tensor_batch[key] = np.concatenate(val, axis=0)

        # merge meta_info across all DataProto objects
        merged_meta_info = {}
        if data:
            # collect per-worker metrics separately so they can be aggregated
            all_metrics = []
            for d in data:
                for k, v in d.meta_info.items():
                    if k == "metrics":
                        if v is not None:
                            if isinstance(v, list):
                                all_metrics.extend(v)
                            else:
                                all_metrics.append(v)
                    else:
                        if k in merged_meta_info:
                            # non-metrics keys must agree across the inputs
                            assert merged_meta_info[k] == v, f"Conflicting values for meta_info key '{k}'"
                        else:
                            merged_meta_info[k] = v

            # merge the collected metrics dicts into a dict of lists
            if all_metrics:
                merged_meta_info["metrics"] = list_of_dict_to_dict_of_list(all_metrics)

        cls = type(data[0]) if len(data) > 0 else DataProto
        return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=merged_meta_info)

    def reorder(self, indices):
        """
        Note that this operation is in-place
        """
        indices_np = indices.detach().numpy()
        self.batch = self.batch[indices]
        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}

    def repeat(self, repeat_times=2, interleave=True):
        """
        Repeat the batch data a specified number of times.

        Args:
            repeat_times (int): Number of times to repeat the data.
            interleave (bool): Whether to interleave the repeated data.

        Returns:
            DataProto: A new DataProto with repeated data.
        """
        if self.batch is not None:
            if interleave:
                # interleave the data: [a, b] -> [a, a, b, b]
                repeated_tensors = {
                    key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
                }
            else:
                # stack the data: [a, b] -> [a, b, a, b]
                repeated_tensors = {
                    key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
                    for key, tensor in self.batch.items()
                }

            repeated_batch = TensorDict(
                source=repeated_tensors,
                batch_size=(self.batch.batch_size[0] * repeat_times,),
            )
        else:
            repeated_batch = None

        repeated_non_tensor_batch = {}
        for key, val in self.non_tensor_batch.items():
            if interleave:
                repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
            else:
                repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))

        return type(self)(
            batch=repeated_batch,
            non_tensor_batch=repeated_non_tensor_batch,
            meta_info=self.meta_info,
        )
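
    # Repeat sketch: with rows [a, b], `repeat(2, interleave=True)` yields
    # [a, a, b, b] (e.g. n samples per prompt), while interleave=False yields
    # [a, b, a, b].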

    def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None):
        """Split along the second dim into `n_split` pieces and unfold them into the first (batch) dim.
        Useful for passing grouped tensors that should not be shuffled in the dataset.
        Keys not in `split_keys` are repeated to match the shape.
        Note that if `split_keys` is not provided, all keys are repeated along the second dim.
        """
        if self.batch is not None:
            unfolded_batch = {}
            for key in self.batch.keys():
                if split_keys is not None and key in split_keys:
                    shape = list(self.batch[key].shape)
                    shape[0] = self.batch[key].shape[0] * n_split
                    shape[1] = self.batch[key].shape[1] // n_split
                    unfolded_batch[key] = self.batch[key].reshape(*shape)
                else:
                    unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0)
            # rebuild as a TensorDict with the enlarged batch dimension
            unfolded_batch = TensorDict(
                source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device
            )
        else:
            unfolded_batch = None

        repeated_non_tensor_batch = {}
        for key, val in self.non_tensor_batch.items():
            if split_keys is not None and key in split_keys:
                shape = list(val.shape)
                shape[0] = val.shape[0] * n_split
                shape[1] = val.shape[1] // n_split
                repeated_non_tensor_batch[key] = val.reshape(*shape)
            else:
                repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0)

        return type(self)(
            batch=unfolded_batch,
            non_tensor_batch=repeated_non_tensor_batch,
            meta_info=self.meta_info,
        )

    def sample_level_repeat(self, repeat_times):
        """
        Repeat each row of the batch data a specified number of times.

        Args:
            repeat_times (torch.Tensor, list, tuple, ndarray): Number of times to repeat each row.

        Returns:
            DataProto: A new DataProto with repeated data.
        """
        if isinstance(repeat_times, tuple):
            repeat_times = list(repeat_times)
        elif isinstance(repeat_times, torch.Tensor | np.ndarray):
            assert len(repeat_times.shape) == 1
            repeat_times = repeat_times.tolist()
        else:
            assert isinstance(repeat_times, list), (
                f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}"
            )
        repeat_times = torch.tensor(repeat_times)

        if self.batch is not None:
            # each row is repeated by its own count
            repeated_tensors = {
                key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
            }

            repeated_batch = TensorDict(
                source=repeated_tensors,
                batch_size=(repeat_times.sum().item(),),
                device=self.batch.device,
            )
        else:
            repeated_batch = None

        repeated_non_tensor_batch = {}
        for key, val in self.non_tensor_batch.items():
            repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)

        return type(self)(
            batch=repeated_batch,
            non_tensor_batch=repeated_non_tensor_batch,
            meta_info=self.meta_info,
        )
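
    # Per-row repeat sketch: with rows [a, b, c] and repeat_times=[1, 2, 0],
    # sample_level_repeat returns rows [a, b, b]; the output length equals
    # sum(repeat_times).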

    def to_tensordict(self) -> TensorDict:
        """Convert this DataProto to a TensorDict. Note that this requires tensordict version at least 0.10.

        Returns:
            TensorDict: a TensorDict holding the batch, with non-tensor data wrapped as NonTensorStack
            and meta_info as NonTensorData.
        """
        assert parse_version(tensordict.__version__) >= parse_version("0.10"), (
            "Convert DataProto to TensorDict at least requires tensordict version 0.10"
        )
        tensor_batch = self.batch.to_dict()
        non_tensor_batch = self.non_tensor_batch

        from tensordict.tensorclass import NonTensorData, NonTensorStack

        from verl.utils import tensordict_utils as tu

        common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
        assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}"

        for key, val in non_tensor_batch.items():
            assert isinstance(val, np.ndarray)
            # wrap each element so the stack keeps the batch dimension
            tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
        output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
        return output

    def get_data_info(self) -> str:
        """Return formatted information about stored data with nested type details.

        Returns:
            str: Formatted string showing tensor details and recursive metadata types
        """
        info = ["batch"]

        for key, tensor in self.batch.items():
            if hasattr(tensor, "shape") and hasattr(tensor, "dtype") and hasattr(tensor, "device"):
                info.append(f"  {key}: {tuple(tensor.shape)} ({tensor.dtype}) {tensor.device}")
            elif hasattr(tensor, "shape") and hasattr(tensor, "dtype"):
                info.append(f"  {key}: {tuple(tensor.shape)} ({tensor.dtype})")
            else:
                info.append(f"  {key}: {type(tensor).__name__}")

        info.append("non_tensor_batch")
        for key, array in self.non_tensor_batch.items():
            info.append(f"  {key}: ndarray{array.shape} ({array.dtype})")

        info.append("meta_info")
        for k, v in self.meta_info.items():
            type_info = self._get_type_info(v)
            info.append(f"  {k}: {type_info}")

        return "\n".join(info)

    def _get_type_info(self, value):
        """Recursively get type information for nested structures"""
        if isinstance(value, list):
            elem_types = {self._get_type_info(v) for v in value[:3]}
            return f"list[{'|'.join(elem_types) if elem_types else '...'}]"
        if isinstance(value, tuple):
            elem_types = [self._get_type_info(v) for v in value]
            return f"tuple({', '.join(elem_types)})"
        if isinstance(value, dict):
            if not value:
                return "dict"
            k, v = next(iter(value.items()))
            return f"dict[{self._get_type_info(k)}: {self._get_type_info(v)}]"
        if isinstance(value, np.ndarray):
            return f"ndarray{value.shape} ({value.dtype})"
        return type(value).__name__


@dataclass
class DataProtoFuture:
    """
    DataProtoFuture aims to eliminate actual data fetching on the driver. By doing so, the driver doesn't have to
    wait for data, so that asynchronous execution becomes possible.
    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
    - collect_fn is a Callable that reduces the list of futures to a DataProto
    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size
      and then selects

    Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
    - DataProtoFuture only supports directly passing the output of one method to the input of another. You can't
      perform any operation on the DataProtoFuture in the driver.
    """

    collect_fn: Callable
    futures: list[ray.ObjectRef]
    dispatch_fn: Callable = None

    @staticmethod
    def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture":
        output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
        return output

    def chunk(self, chunks: int) -> list["DataProtoFuture"]:
        from functools import partial

        arg_future_lst = []
        for i in range(chunks):
            # bind i and chunks via partial; a closure would capture the loop variable by reference
            def dispatch_fn(x, i, chunks):
                return x.chunk(chunks=chunks)[i]

            arg_future = DataProtoFuture(
                collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures
            )
            arg_future_lst.append(arg_future)
        return arg_future_lst

    def get(self):
        output = ray.get(self.futures)
        for o in output:
            assert isinstance(o, DataProto | TensorDict)

        if isinstance(output[0], DataProto):
            output = DataProto.concat(output)
        elif isinstance(output[0], TensorDict):
            from verl.utils.tensordict_utils import concat_tensordict

            output = concat_tensordict(output)
        else:
            raise TypeError(f"Unknown type {type(output[0])} in DataProtoFuture")

        if self.dispatch_fn is not None:
            output = self.dispatch_fn(output)
        return output


def all_gather_data_proto(data: DataProto, process_group):
    # Note: this is an in-place operation, just like torch.distributed.all_gather
    group_size = torch.distributed.get_world_size(group=process_group)
    assert isinstance(data, DataProto)
    prev_device = data.batch.device
    data = data.to(get_device_id())
    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)
    data = data.to(prev_device)
    # all-gather the non-tensor batch via the generic object collective
    all_non_tensor_batch = [None for _ in range(group_size)]
    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group)
    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
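

# Usage sketch (requires an initialized torch.distributed process group): every
# rank passes its local DataProto and, because the operation is in place, ends
# up holding the concatenation of all ranks' data:
#
#   all_gather_data_proto(local_data, process_group=group)
#   # local_data now holds group_size * original_length items on every rank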