|
|
import functools |
|
|
from collections import namedtuple |
|
|
|
|
|
from typing import Callable, Iterator, Sized, TypeVar, Optional, Union, Any, Dict, List |
|
|
|
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe |
|
|
from torch.utils.data._utils.collate import default_collate |
|
|
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper |
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe |
|
|
from torch.utils.data.datapipes.utils.common import (_check_unpickable_fn, |
|
|
validate_input_col) |
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["CollatorIterDataPipe", "MapperIterDataPipe"]

# Covariant type parameter used for the items yielded by the DataPipes below.
T_co = TypeVar("T_co", covariant=True)
|
|
|
|
|
|
|
|
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    function is not recommended as it is not supported by pickle.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of data which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable.

    Raises:
        ValueError: If ``output_col`` is given while ``input_col`` is ``None``, or if
            ``output_col`` is a list/tuple that does not contain exactly one element.

    .. note::
        When ``input_col`` is set, list and dict items are updated in place; tuple items
        are copied into a list, updated, and returned as a new tuple.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """
    datapipe: IterDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(fn)
        self.fn = fn

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # Require exactly one element: an empty sequence must be rejected here
            # rather than surfacing as an IndexError from the lookup below.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col
        validate_input_col(fn, input_col)

    def _apply_fn(self, data):
        """Apply ``self.fn`` to ``data`` according to ``input_col`` / ``output_col``.

        Returns the result of ``fn`` directly when no columns are configured;
        otherwise returns ``data`` with the result written into the selected position.
        """
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            # __init__ rejects this combination; it is only reachable if the
            # attributes were changed after construction.
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Tuples are immutable: work on a list copy and convert back at the end.
        t_flag = isinstance(data, tuple)
        if t_flag:
            data = list(data)

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                # The left-most input column receives the result; the others are removed.
                data[self.input_col[0]] = res
                # Delete in descending order so earlier deletions do not shift
                # the positions of the list entries still to be removed.
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                data[self.input_col] = res
        elif self.output_col == -1:
            data.append(res)
        else:
            data[self.output_col] = res

        return tuple(data) if t_flag else data

    def __iter__(self) -> Iterator[T_co]:
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        """Return the length of the source DataPipe, raising ``TypeError`` if it is not ``Sized``."""
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(
            f"{type(self).__name__} instance doesn't have valid length"
        )
|
|
|
|
|
|
|
|
def _collate_helper(conversion, item):
    """Collate the single DataFrame held by ``item`` into a ``CollateResult`` namedtuple.

    Args:
        conversion: Mapping from column name to the callable used to collate that
            column. Every key must be a column of the DataFrame.
        item: Batch object exposing ``items`` and positional indexing; it must not
            hold more than one DataFrame.

    Returns:
        A ``CollateResult`` namedtuple with one field per DataFrame column (named
        ``str(column)``), holding the collated value of that column.

    Raises:
        Exception: If the batch holds more than one DataFrame, a conversion key is
            not a column, a conversion value is not callable, or the default
            collation function cannot be obtained from TorchArrow for a column
            that has no entry in ``conversion``.
    """
    if len(item.items) > 1:
        raise Exception("Only supports one DataFrame per batch")
    df = item[0]
    columns_name = df_wrapper.get_columns(df)
    tuple_names: List = []
    tuple_values: List = []

    # Reject conversion entries that do not correspond to any column.
    for name in conversion:
        if name not in columns_name:
            raise Exception("Conversion keys missmatch")

    for name in columns_name:
        if name in conversion:
            collation_fn = conversion[name]
            if not callable(collation_fn):
                raise Exception('Collate (DF)DataPipe requires callable as dict values')
        else:
            # Columns without an explicit entry use TorchArrow's default collation.
            try:
                import torcharrow.pytorch as tap
                collation_fn = tap.rec.Default()
            except Exception as err:
                # Chain the original failure so the root cause stays visible.
                raise Exception("unable to import default collation function from the TorchArrrow") from err

        tuple_names.append(str(name))
        tuple_values.append(collation_fn(df[name]))

    tpl_cls = namedtuple("CollateResult", tuple_names)
    return tpl_cls(*tuple_values)
|
|
|
|
|
|
|
|
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).

    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        conversion: Either a callable applied to each item, or a dict mapping DataFrame
            column names to the callable used to collate that column. ``None`` selects
            :func:`torch.utils.data.default_collate`. Ignored when ``collate_fn`` is given.
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Takes precedence over ``conversion`` when it is not ``None``.
            Default function collates to Tensor(s) based on data type.

    Example: Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> # xdoctest: +SKIP
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        conversion: Optional[
            Union[
                Callable[..., Any],
                Dict[Union[str, Any], Union[Callable, Any]],
            ]
        ] = default_collate,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        if collate_fn is not None:
            # An explicit ``collate_fn`` always wins over ``conversion``.
            fn = collate_fn
        elif conversion is None:
            # ``conversion`` is annotated Optional: treat None as "use the default"
            # instead of building a helper that can only fail during iteration.
            fn = default_collate
        elif callable(conversion):
            fn = conversion
        else:
            # Non-callable conversion is handled as a per-column mapping for DataFrames.
            fn = functools.partial(_collate_helper, conversion)
        super().__init__(datapipe, fn=fn)
|
|
|