|
|
from typing import Callable, Iterator, Optional, TypeVar |
|
|
|
|
|
from torch.utils.data.datapipes._decorator import functional_datapipe |
|
|
from torch.utils.data.datapipes.datapipe import IterDataPipe |
|
|
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper |
|
|
from torch.utils.data.datapipes.utils.common import ( |
|
|
_check_unpickable_fn, |
|
|
_deprecation_warning, |
|
|
StreamWrapper, |
|
|
validate_input_col |
|
|
) |
|
|
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["FilterIterDataPipe", ]

# Covariant type variable used as the element type of ``IterDataPipe[T_co]``.
T_co = TypeVar('T_co', covariant=True)
|
|
|
|
|
|
|
|
@functional_datapipe('filter')
class FilterIterDataPipe(IterDataPipe[T_co]):
    r"""
    Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).

    Args:
        datapipe: Iterable DataPipe being filtered
        filter_fn: Customized function mapping an element to a boolean.
        drop_empty_batches (Deprecated): By default, drops a batch if it is empty after filtering instead of keeping an empty list
        input_col: Index or indices of data which ``filter_fn`` is applied, such as:

            - ``None`` as default to apply ``filter_fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

    Raises:
        ValueError: While iterating, if ``filter_fn`` returns something that is
            neither a ``bool`` nor a dataframe column (as decided by
            ``df_wrapper.is_column``).

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def is_even(n):
        ...     return n % 2 == 0
        >>> dp = IterableWrapper(range(5))
        >>> filter_dp = dp.filter(filter_fn=is_even)
        >>> list(filter_dp)
        [0, 2, 4]
    """
    datapipe: IterDataPipe
    filter_fn: Callable
    drop_empty_batches: bool

    def __init__(
        self,
        datapipe: IterDataPipe,
        filter_fn: Callable,
        drop_empty_batches: Optional[bool] = None,
        input_col=None,
    ) -> None:
        """Store the source datapipe, the predicate and the column selection.

        Passing any explicit value for ``drop_empty_batches`` (even ``True``)
        emits a deprecation warning; leaving it as ``None`` silently selects
        ``True``.
        """
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(filter_fn)
        self.filter_fn = filter_fn

        if drop_empty_batches is None:
            drop_empty_batches = True
        else:
            # Only an explicitly supplied argument counts as use of the
            # deprecated parameter.
            _deprecation_warning(
                type(self).__name__,
                deprecation_version="1.12",
                removal_version="1.14",
                old_argument_name="drop_empty_batches",
            )
        self.drop_empty_batches = drop_empty_batches

        self.input_col = input_col
        validate_input_col(filter_fn, input_col)

    def _apply_filter_fn(self, data) -> bool:
        """Call ``filter_fn`` on ``data`` or on the part(s) chosen by ``input_col``.

        - ``input_col is None``: ``filter_fn(data)``.
        - ``input_col`` is a list/tuple: ``filter_fn(*[data[col] ...])``, one
          positional argument per column.
        - otherwise: ``filter_fn(data[input_col])``.

        The result is returned unchecked; ``_returnIfTrue`` validates it.
        """
        if self.input_col is None:
            return self.filter_fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            return self.filter_fn(*args)
        else:
            return self.filter_fn(data[self.input_col])

    def __iter__(self) -> Iterator[T_co]:
        """Yield the elements of the source datapipe that pass the filter.

        Elements that are not yielded are handed to
        ``StreamWrapper.close_streams``.
        """
        for data in self.datapipe:
            filtered = self._returnIfTrue(data)
            if self._isNonEmpty(filtered):
                yield filtered
            else:
                StreamWrapper.close_streams(data)

    def _returnIfTrue(self, data):
        """Return the kept portion of ``data``, or ``None`` if nothing is kept.

        When ``filter_fn`` yields a dataframe column, it is treated as a
        per-row mask: rows with a truthy mask are collected and concatenated.
        Otherwise the result must be a ``bool``; ``data`` is returned when it
        is ``True``.

        Raises:
            ValueError: If the predicate result is neither a column nor a
                ``bool``.
        """
        condition = self._apply_filter_fn(data)

        if df_wrapper.is_column(condition):
            # Row-wise filtering: keep only the rows whose mask is truthy.
            result = [
                df_wrapper.get_item(data, idx)
                for idx, mask in enumerate(df_wrapper.iterate(condition))
                if mask
            ]
            if result:
                return df_wrapper.concat(result)
            return None

        if not isinstance(condition, bool):
            # Build one formatted message; passing the type as a second
            # positional argument would render the error as a tuple repr.
            raise ValueError(
                "Boolean output is required for `filter_fn` of FilterIterDataPipe, "
                f"got {type(condition)}"
            )
        if condition:
            return data
        return None

    def _isNonEmpty(self, data):
        """Tell whether ``data`` (an output of ``_returnIfTrue``) should be yielded.

        Dataframes are always kept, ``None`` is never kept, and an empty list
        is dropped only while ``drop_empty_batches`` is truthy.
        """
        if df_wrapper.is_dataframe(data):
            return True
        if data is None:
            return False
        is_empty_list = isinstance(data, list) and len(data) == 0
        return not (is_empty_list and self.drop_empty_batches)
|
|
|