| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Implement the base data transfer protocol between any two functions or modules. |
| The protocol can be subclassed to define more detailed batch info with specific keys. |
| """ |
|
|
| import numpy as np |
| import copy |
| from dataclasses import dataclass, field |
| from typing import Callable, Dict, List, Union |
|
|
| import torch |
| import tensordict |
| from tensordict import TensorDict |
| from torch.utils.data import DataLoader, Dataset |
|
|
| from verl.utils.py_functional import union_two_dict |
|
|
__all__ = ['DataProto', 'union_tensor_dict',]

# Best-effort switch of tensordict's lazy-legacy mode to False at import time.
# NOTE(review): presumably this guards against tensordict builds that lack or
# reject ``set_lazy_legacy`` - confirm which versions need the fallback.
try:
    tensordict.set_lazy_legacy(False).set()
except Exception:
    # Deliberately ignored: the module must stay importable when the call fails.
    # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt / SystemExit propagate.
    pass
|
|
|
|
def union_tensor_dict(tensor_dict1: 'TensorDict', tensor_dict2: 'TensorDict') -> 'TensorDict':
    """Merge ``tensor_dict2`` into ``tensor_dict1`` in place and return ``tensor_dict1``.

    Keys found only in ``tensor_dict2`` are added to ``tensor_dict1``. Keys found in both
    must hold values for which ``tensor_dict1[key].equal(tensor_dict2[key])`` is true.

    Raises:
        AssertionError: if the two batch sizes differ, or a shared key holds unequal values.
    """
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
        f'Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}'

    for key in tensor_dict2.keys():
        incoming = tensor_dict2[key]
        # membership is tested on ``.keys()`` rather than on the container itself
        if key in tensor_dict1.keys():
            assert tensor_dict1[key].equal(incoming), \
                f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
            continue
        tensor_dict1[key] = incoming

    return tensor_dict1
|
|
|
|
def union_numpy_dict(tensor_dict1: dict[np.ndarray], tensor_dict2: dict[np.ndarray]) -> dict[np.ndarray]:
    """Merge ``tensor_dict2`` into ``tensor_dict1`` in place and return ``tensor_dict1``.

    Every key of ``tensor_dict2`` is written into ``tensor_dict1``. When a key already
    exists, both values must be ``np.ndarray`` instances with the same shape and equal
    elements; otherwise the merge is rejected before that key is overwritten.

    Args:
        tensor_dict1: destination dict, mutated in place.
        tensor_dict2: dict whose entries are merged in.

    Returns:
        ``tensor_dict1`` (the same object that was passed in).

    Raises:
        AssertionError: if a shared key holds a non-array value or the two arrays differ.
    """
    for key, val in tensor_dict2.items():
        if key in tensor_dict1:
            existing = tensor_dict1[key]
            assert isinstance(val, np.ndarray)
            assert isinstance(existing, np.ndarray)
            # ``array_equal`` compares shapes before elements. A plain ``==`` would
            # broadcast, so a length-1 array would wrongly match a longer one (and a
            # non-broadcastable mismatch can raise instead of asserting).
            assert np.array_equal(val, existing), \
                f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
        tensor_dict1[key] = val

    return tensor_dict1
|
|
|
|
def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
    """Transpose a list of dicts into a dict of lists.

    The key set is taken from the first dict. Each value is appended to the list of
    its key in input order. An empty input yields an empty dict.

    Raises:
        AssertionError: if a later dict contains a key that the first dict lacks.
    """
    if len(list_of_dict) == 0:
        return {}

    # one independent list per key of the first element
    collected = {}
    for name in list_of_dict[0].keys():
        collected[name] = []

    for entry in list_of_dict:
        for name, value in entry.items():
            assert name in collected
            collected[name].append(value)
    return collected
|
|
|
|
def collate_fn(x: list['DataProtoItem']):
    """Collate a list of ``DataProtoItem`` into a single ``DataProto``.

    The ``batch`` entries are stacked with ``torch.stack`` and made contiguous. The
    ``non_tensor_batch`` dicts are transposed key-wise and each resulting list is
    turned into a numpy array with ``dtype=object``. ``meta_info`` is not set here.
    """
    tensor_items = [item.batch for item in x]
    object_items = [item.non_tensor_batch for item in x]

    stacked = torch.stack(tensor_items).contiguous()
    columns = list_of_dict_to_dict_of_list(object_items)
    object_arrays = {key: np.array(values, dtype=object) for key, values in columns.items()}
    return DataProto(batch=stacked, non_tensor_batch=object_arrays)
|
|
|
|
@dataclass
class DataProtoItem:
    """A single indexed element of a ``DataProto``, as built by ``DataProto.__getitem__``."""

    # the tensor batch indexed at the requested position (None when absent)
    batch: 'TensorDict' = None
    # per-key values taken from the non-tensor arrays at the same position
    non_tensor_batch: Dict = field(default_factory=dict)
    # metadata dict handed over by whoever built the item
    meta_info: Dict = field(default_factory=dict)
|
|
|
|
@dataclass
class DataProto:
    """
    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
    It contains a batch (TensorDict), a non_tensor_batch (dict of numpy arrays with dtype=object) and a
    meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
    same batch size should be put inside batch.

    ``batch`` may be None. ``__len__``, ``__getitem__``, ``slice``, ``check_consistency``, ``to``, ``chunk``,
    ``concat`` and ``reorder`` handle that case; the remaining methods require a batch.
    """
    batch: 'TensorDict' = None
    non_tensor_batch: Dict = field(default_factory=dict)
    meta_info: Dict = field(default_factory=dict)

    def __post_init__(self):
        # validate eagerly so that malformed data fails at construction time
        self.check_consistency()

    def __len__(self):
        """Return the size of the batch dimension (0 when there is no data at all)."""
        if self.batch is not None:
            return self.batch.batch_size[0]
        # without a tensor batch, the non-tensor arrays define the length
        first = next(iter(self.non_tensor_batch.values()), None)
        return 0 if first is None else first.shape[0]

    def __getitem__(self, item):
        """Index both batch and non_tensor_batch with ``item`` and wrap the result in a DataProtoItem."""
        tensor_data = self.batch[item] if self.batch is not None else None
        non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
        return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

    def slice(self, index):
        """Index both batch and non_tensor_batch with ``index`` and return the result as a new DataProto.

        The new DataProto shares ``meta_info`` with this one. ``index`` must keep the batch dimension
        (e.g. a slice or an index array), because the result is validated by ``check_consistency``.
        """
        tensor_data = self.batch[index] if self.batch is not None else None
        non_tensor_data = {key: val[index] for key, val in self.non_tensor_batch.items()}
        return DataProto(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

    def slice_batch(self, start, length, dim=0):
        """
        Narrow every tensor in the batch to ``length`` entries starting at ``start`` along ``dim``.

        Note that this operation is in-place. Only ``batch`` is touched; ``non_tensor_batch`` is left as is.
        """
        # snapshot the items first so that entries are not reassigned while the view is being iterated
        for key, val in list(self.batch.items()):
            self.batch[key] = val.narrow(start=start, length=length, dim=dim)

    @staticmethod
    def _version_tuple(version_string, width=3):
        """Return the leading numeric release components of ``version_string`` as a tuple of ``width`` ints."""
        import re

        numbers = []
        for piece in str(version_string).split('.'):
            match = re.match(r'\d+', piece)
            if match is None:
                break
            numbers.append(int(match.group()))
            if len(numbers) == width:
                break
        # pad short versions such as '0.5' so that they compare like '0.5.0'
        numbers.extend([0] * (width - len(numbers)))
        return tuple(numbers)

    def __getstate__(self):
        """Serialize to ``(BytesIO holding the torch-saved batch, non_tensor_batch, meta_info)``.

        For tensordict >= 0.5.0 the batch is made contiguous and consolidated first; this replaces
        ``self.batch`` as a side effect.
        """
        import io
        buffer = io.BytesIO()
        # Compare numerically: as strings, '0.10.0' >= '0.5.0' would be False.
        if self.batch is not None and self._version_tuple(tensordict.__version__) >= (0, 5, 0):
            self.batch = self.batch.contiguous()
            self.batch = self.batch.consolidate()
        torch.save(self.batch, buffer)
        return buffer, self.non_tensor_batch, self.meta_info

    def __setstate__(self, data):
        """Restore the three fields from the tuple produced by ``__getstate__``."""
        batch_deserialized, non_tensor_batch, meta_info = data
        batch_deserialized.seek(0)
        # SECURITY NOTE: weights_only=False performs full unpickling and can execute arbitrary code.
        # Only deserialize DataProto objects that come from trusted peers.
        batch = torch.load(batch_deserialized,
                           weights_only=False,
                           map_location='cpu' if not torch.cuda.is_available() else None)
        self.batch = batch
        self.non_tensor_batch = non_tensor_batch
        self.meta_info = meta_info

    def check_consistency(self):
        """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
        We expose this function as a public one so that user can call themselves directly

        Raises:
            AssertionError: if the batch has more than one batch dim, if a non-tensor value is not a
                numpy array with dtype=object, or if its length differs from the batch size (or, when
                there is no batch, from the length of the first non-tensor array).
        """
        if self.batch is not None:
            assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'

        if len(self.non_tensor_batch) != 0:
            # With no tensor batch the reference size is taken from the first non-tensor array.
            batch_size = self.batch.batch_size[0] if self.batch is not None else None
            for key, val in self.non_tensor_batch.items():
                assert isinstance(
                    val, np.ndarray
                ) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object'
                if batch_size is None:
                    batch_size = val.shape[0]
                assert val.shape[
                    0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}'

    @classmethod
    def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
        """Create a DataProto from one dict that mixes tensors and numpy arrays.

        ``torch.Tensor`` values go to the batch, ``np.ndarray`` values to non_tensor_batch.

        Raises:
            ValueError: if a value is neither a tensor nor a numpy array.
        """
        tensors = {}
        non_tensors = {}

        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key] = val
            elif isinstance(val, np.ndarray):
                non_tensors[key] = val
            else:
                raise ValueError(f'Unsupported type in data {type(val)}')

        # ``cls`` (not a hard-coded DataProto) so that subclasses get instances of their own type
        return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)

    @classmethod
    def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
        """Create a DataProto from a dict of tensors. This assumes that
        1. All the tensor in tensors have the same dim0
        2. Only dim0 is the batch dim

        Args:
            tensors: non-empty dict of tensors sharing their leading ``num_batch_dims`` dimensions.
            non_tensors: optional dict; every value is converted with ``np.array(val, dtype=object)``
                into a new dict, the caller's dict is left untouched.
            meta_info: optional metadata dict.
            num_batch_dims: number of leading batch dims. Note that ``check_consistency`` only accepts 1.
        """
        assert len(tensors) > 0, 'tensors must not be empty'
        assert num_batch_dims > 0, 'num_batch_dims must be greater than zero'
        if non_tensors is not None:
            assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.'

        if meta_info is None:
            meta_info = {}
        if non_tensors is None:
            non_tensors = {}

        assert isinstance(non_tensors, dict)

        # take the batch size of the first tensor as the pivot and compare all others against it
        batch_size = None
        pivot_key = None
        for key, tensor in tensors.items():
            if batch_size is None:
                batch_size = tensor.shape[:num_batch_dims]
                pivot_key = key
            else:
                current_batch = tensor.shape[:num_batch_dims]
                assert batch_size == current_batch, \
                    f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}'

        # build a fresh dict instead of rewriting the argument in place
        non_tensors = {key: np.array(val, dtype=object) for key, val in non_tensors.items()}

        tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
        return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)

    def to(self, device) -> 'DataProto':
        """move the batch to device

        Args:
            device (torch.device, str): torch device

        Returns:
            DataProto: the current DataProto

        """
        if self.batch is not None:
            self.batch = self.batch.to(device)
        return self

    def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
        """Select a subset of the DataProto via batch_keys, non_tensor_batch_keys and meta_info_keys

        Args:
            batch_keys (list, optional): a list of strings indicating the keys in batch to select
            non_tensor_batch_keys (list, optional): a list of keys indicating the non-tensor data to select
            meta_info_keys (list, optional): a list of keys indicating the meta info to select
            deepcopy (bool): if True, the selected non-tensor data and meta info are passed through
                ``copy.deepcopy``. The selected batch is not deep-copied here.

        Returns:
            DataProto: the DataProto with the selected keys. A ``None`` key list selects everything
            of that kind.
        """
        if batch_keys is not None:
            batch_keys = tuple(batch_keys)
            sub_batch = self.batch.select(*batch_keys)
        else:
            sub_batch = self.batch

        if non_tensor_batch_keys is not None:
            non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
        else:
            non_tensor_batch = self.non_tensor_batch

        if deepcopy:
            non_tensor_batch = copy.deepcopy(non_tensor_batch)

        if meta_info_keys is not None:
            sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
        else:
            sub_meta_info = self.meta_info

        if deepcopy:
            sub_meta_info = copy.deepcopy(sub_meta_info)

        return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)

    def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
        """Pop a subset of the DataProto via `batch_keys`, `non_tensor_batch_keys` and `meta_info_keys`

        All keys are validated before anything is removed, so a failed call leaves this object unchanged.

        Args:
            batch_keys (list): a non-empty list of strings indicating the keys in batch to pop
            non_tensor_batch_keys (list, optional): a list of keys indicating the non-tensor data to pop
            meta_info_keys (list, optional): a list of keys indicating the meta info to pop

        Returns:
            DataProto: the DataProto with the popped keys

        Raises:
            AssertionError: if batch_keys is None or empty, or if any requested key does not exist.
        """
        assert batch_keys is not None
        batch_keys = list(batch_keys)
        meta_info_keys = [] if meta_info_keys is None else list(meta_info_keys)
        non_tensor_batch_keys = [] if non_tensor_batch_keys is None else list(non_tensor_batch_keys)

        # from_dict rejects an empty tensor dict; fail before any state has been modified
        assert len(batch_keys) > 0, 'tensors must not be empty'
        for key in batch_keys:
            assert key in self.batch.keys(), f'key {key} is not in batch'
        for key in non_tensor_batch_keys:
            assert key in self.non_tensor_batch.keys(), f'key {key} is not in non_tensor_batch'
        for key in meta_info_keys:
            assert key in self.meta_info.keys(), f'key {key} is not in meta_info'

        tensors = {key: self.batch.pop(key) for key in batch_keys}
        non_tensors = {key: self.non_tensor_batch.pop(key) for key in non_tensor_batch_keys}
        meta_info = {key: self.meta_info.pop(key) for key in meta_info_keys}
        return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)

    def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
        """
        Rename keys of the batch in place and return self.

        Note that this function only rename the key in the batch

        Args:
            old_keys (str, list or tuple): existing key(s).
            new_keys (str, list or tuple): replacement key(s), same length as ``old_keys``.

        Raises:
            TypeError: if either argument is not a string, list or tuple (this includes None).
            ValueError: if the two key lists differ in length.
        """

        def validate_input(keys):
            if isinstance(keys, str):
                return [keys]
            if isinstance(keys, (list, tuple)):
                return list(keys)
            raise TypeError(f'keys must be a list or a string, but got {type(keys)}')

        old_keys = validate_input(old_keys)
        new_keys = validate_input(new_keys)

        if len(new_keys) != len(old_keys):
            raise ValueError(
                f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}')

        # Rename one pair at a time: tensordict interprets a tuple argument as a single nested key,
        # not as a collection of keys.
        for old_key, new_key in zip(old_keys, new_keys):
            self.batch.rename_key_(old_key, new_key)

        return self

    def union(self, other: 'DataProto') -> 'DataProto':
        """Union with another DataProto. Union batch, non_tensor_batch and meta_info separately.
        Throw an error if
        - there are conflict keys in batch and they are not equal
        - the batch size of two data batch is not the same
        - there are conflict keys in meta_info and they are not the same.

        This object is modified in place.

        Args:
            other (DataProto): another DataProto to union

        Returns:
            DataProto: the DataProto after union
        """
        self.batch = union_tensor_dict(self.batch, other.batch)
        self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
        self.meta_info = union_two_dict(self.meta_info, other.meta_info)
        return self

    def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
        """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
        dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.

        Args:
            mini_batch_size (int): mini-batch size when iterating the dataset. We require that
                ``batch.batch_size[0] % mini_batch_size == 0``
            epochs (int): number of epochs when iterating the dataset.
            seed (int, optional): if given, seeds the ``torch.Generator`` handed to the DataLoader.
            dataloader_kwargs: internally, it returns a DataLoader over the batch.
                The dataloader_kwargs is the kwargs passed to the DataLoader

        Returns:
            Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
            ``self.batch.batch_size * epochs // mini_batch_size``
        """
        assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"

        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = None

        assert isinstance(dataloader_kwargs, dict)
        train_dataloader = DataLoader(dataset=self,
                                      batch_size=mini_batch_size,
                                      collate_fn=collate_fn,
                                      generator=generator,
                                      **dataloader_kwargs)

        def get_data():
            for _ in range(epochs):
                for d in train_dataloader:
                    # collate_fn does not carry meta_info over, so attach it to every mini-batch
                    d.meta_info = self.meta_info
                    yield d

        return iter(get_data())

    def chunk(self, chunks: int) -> List['DataProto']:
        """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.

        The non-tensor arrays are cut at the same boundaries as the tensor batch, so every chunk stays
        consistent even when the batch size is not divisible by ``chunks``. Without a tensor batch the
        arrays are split with ``np.array_split``.

        Args:
            chunks (int): the number of chunks to split on dim=0

        Returns:
            List[DataProto]: a list of DataProto after splitting

        Raises:
            AssertionError: if the batch cannot be split into exactly ``chunks`` parts.
        """
        if self.batch is not None:
            batch_lst = self.batch.chunk(chunks=chunks, dim=0)
            assert len(batch_lst) == chunks, (f"len(batch_lst) ({len(batch_lst)}) != chunks ({chunks})")
            # boundaries between consecutive batch chunks, reused for the non-tensor arrays
            split_points = np.cumsum([part.batch_size[0] for part in batch_lst])[:-1].tolist()
        else:
            batch_lst = [None for _ in range(chunks)]
            split_points = None

        non_tensor_batch_lst = [{} for _ in range(chunks)]
        for key, val in self.non_tensor_batch.items():
            assert isinstance(val, np.ndarray)
            if split_points is None:
                non_tensor_lst = np.array_split(val, chunks)
            else:
                non_tensor_lst = np.split(val, split_points)
            assert len(non_tensor_lst) == chunks
            for i in range(chunks):
                non_tensor_batch_lst[i][key] = non_tensor_lst[i]

        assert len(non_tensor_batch_lst) == chunks, (f"len(non_tensor_batch_lst) ({len(non_tensor_batch_lst)}) != chunks ({chunks})")
        return [
            DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)
            for i in range(chunks)
        ]

    @staticmethod
    def concat(data: List['DataProto']) -> 'DataProto':
        """Concat a list of DataProto. The batch is concatenated among dim=0.
        The meta_info is assumed to be identical and will use the first one.

        Args:
            data (List[DataProto]): non-empty list of DataProto

        Returns:
            DataProto: concatenated DataProto
        """
        batch_lst = [batch.batch for batch in data]
        # the first element decides whether there is a tensor batch to concatenate at all
        if batch_lst[0] is not None:
            new_batch = torch.cat(batch_lst, dim=0)
        else:
            new_batch = None

        non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
        for key, val in non_tensor_batch.items():
            non_tensor_batch[key] = np.concatenate(val, axis=0)

        return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)

    def reorder(self, indices):
        """
        Reorder batch and non_tensor_batch along dim 0 by ``indices`` (a torch tensor of indices).

        Note that this operation is in-place
        """
        # ``.cpu()`` is a no-op for CPU tensors and makes ``.numpy()`` valid for tensors on other devices
        indices_np = indices.detach().cpu().numpy()
        if self.batch is not None:
            self.batch = self.batch[indices]
        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
|
|
|
|
| import ray |
|
|
|
|
@dataclass
class DataProtoFuture:
    """
    Handle to data that has not been fetched yet, so that the driver does not have to wait for it and
    asynchronous execution becomes possible.

    A DataProtoFuture holds a list of futures coming from another WorkerGroup of size world_size.
    - collect_fn is a Callable that reduces the list of fetched results to a single DataProto
    - dispatch_fn is an optional Callable applied to the collected DataProto, e.g. to pick one chunk of it

    Known limitations:
    - dispatch_fn(collect_fn(...)) fetches everything; it could be optimized so that only the needed data
      is fetched on the destination
    - a DataProtoFuture can only be passed from the output of one method to the input of another; no
      operation can be performed on it in the driver
    """
    collect_fn: Callable
    futures: List['ray.ObjectRef']
    dispatch_fn: Callable = None

    @staticmethod
    def concat(data: List['ray.ObjectRef']) -> 'DataProtoFuture':
        """Wrap a list of futures so that their results are merged with ``DataProto.concat`` on ``get``."""
        return DataProtoFuture(collect_fn=DataProto.concat, futures=data)

    def chunk(self, chunks: int) -> List['DataProtoFuture']:
        """Return ``chunks`` futures; the i-th one resolves to the i-th chunk of the collected data."""
        from functools import partial

        def dispatch_fn(x, i, chunks):
            return x.chunk(chunks=chunks)[i]

        # every part shares collect_fn and futures; only the bound chunk index differs
        return [
            DataProtoFuture(collect_fn=self.collect_fn,
                            dispatch_fn=partial(dispatch_fn, i=index, chunks=chunks),
                            futures=self.futures) for index in range(chunks)
        ]

    def get(self):
        """Fetch all futures via ``ray.get``, collect them, and apply ``dispatch_fn`` when one is set."""
        fetched = ray.get(self.futures)
        for result in fetched:
            assert isinstance(result, DataProto)
        collected = self.collect_fn(fetched)
        if self.dispatch_fn is None:
            return collected
        return self.dispatch_fn(collected)
|
|