File size: 6,751 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection
# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt
# Note that, for mypy, a .pyi file takes precedence over the corresponding .py file, so we must define the interface
# for other classes/objects here, even though we are not injecting extra code into them at the moment.

from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar, Union
from torch.utils.data import Dataset, IterableDataset, default_collate

# Covariant type variable; used below to parameterize Dataset / IterableDataset
# in the MapDataPipe and IterDataPipe declarations.
T_co = TypeVar('T_co', covariant=True)
# Plain (invariant) type variable; used below as the element type of DataChunk.
T = TypeVar('T')
# Declared with type Any and no value in this stub.
# NOTE(review): the name suggests a collection of DataFrame pipes excluded from
# tracing -- confirm against the runtime module before relying on that.
UNTRACABLE_DATAFRAME_PIPES: Any


class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
    """Type stub for the map-style DataPipe base class.

    Declares the hook attributes, the registration / hook-setter classmethods,
    and the functional-form methods (``batch``, ``concat``, ``map``,
    ``shuffle``, ``zip``) so that type checkers can see them. All bodies and
    attribute values are elided with ``...``.
    """
    # Attributes declared in the class body; their values are elided in this stub.
    # NOTE(review): `functions` presumably maps a registered functional name to
    # its callable -- confirm against the runtime class.
    functions: Dict[str, Callable] = ...
    reduce_ex_hook: Optional[Callable] = ...
    getstate_hook: Optional[Callable] = ...
    str_hook: Optional[Callable] = ...
    repr_hook: Optional[Callable] = ...
    # Attribute-lookup fallback; the return type is intentionally left unannotated here.
    def __getattr__(self, attribute_name: Any): ...
    @classmethod
    def register_function(cls, function_name: Any, function: Any) -> None: ...
    @classmethod
    def register_datapipe_as_function(cls, function_name: Any, cls_to_register: Any): ...
    # Pickling-related methods; return types are left unannotated in this stub.
    def __getstate__(self): ...
    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
    @classmethod
    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
    @classmethod
    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
    # Functional form of 'BatcherMapDataPipe'
    # `wrapper_class` defaults to DataChunk, which is declared further down in this file.
    def batch(self, batch_size: int, drop_last: bool = False, wrapper_class=DataChunk) -> MapDataPipe: ...
    # Functional form of 'ConcaterMapDataPipe'
    def concat(self, *datapipes: MapDataPipe) -> MapDataPipe: ...
    # Functional form of 'MapperMapDataPipe'
    def map(self, fn: Callable= ...) -> MapDataPipe: ...
    # Functional form of 'ShufflerIterDataPipe'
    # NOTE: unlike the other functional forms on this class, this one is declared
    # to return an IterDataPipe rather than a MapDataPipe.
    def shuffle(self, *, indices: Optional[List] = None) -> IterDataPipe: ...
    # Functional form of 'ZipperMapDataPipe'
    def zip(self, *datapipes: MapDataPipe[T_co]) -> MapDataPipe: ...


class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
    """Type stub for the iterable-style DataPipe base class.

    Declares the hook attributes, private iteration-state attributes, the
    registration / hook-setter classmethods, and the functional-form methods
    listed below so that type checkers can see them. Method bodies are elided
    with ``...``.
    """
    # Attributes declared in the class body; their values are elided in this stub.
    # NOTE(review): `functions` presumably maps a registered functional name to
    # its callable -- confirm against the runtime class.
    functions: Dict[str, Callable] = ...
    reduce_ex_hook: Optional[Callable] = ...
    getstate_hook: Optional[Callable] = ...
    str_hook: Optional[Callable] = ...
    repr_hook: Optional[Callable] = ...
    # Private iteration-state attributes.
    # NOTE(review): presumably bookkeeping for snapshot / fast-forward support
    # (see the `_hook_iterator` import) -- confirm against the runtime class.
    _number_of_samples_yielded: int = ...
    # The only attribute here given a concrete value instead of `...`.
    # NOTE(review): confirm this matches the default used by the runtime class.
    _snapshot_state: _SnapshotState = _SnapshotState.Iterating
    _fast_forward_iterator: Optional[Iterator] = ...
    # Attribute-lookup fallback; the return type is intentionally left unannotated here.
    def __getattr__(self, attribute_name: Any): ...
    @classmethod
    def register_function(cls, function_name: Any, function: Any) -> None: ...
    # Compared with MapDataPipe's declaration, this one takes an extra
    # `enable_df_api_tracing` flag.
    @classmethod
    def register_datapipe_as_function(cls, function_name: Any, cls_to_register: Any, enable_df_api_tracing: bool = ...): ...
    # Pickling-related methods; return types are left unannotated in this stub.
    def __getstate__(self): ...
    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
    @classmethod
    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
    @classmethod
    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
    # Functional form of 'BatcherIterDataPipe'
    # `wrapper_class` defaults to DataChunk, which is declared further down in this file.
    def batch(self, batch_size: int, drop_last: bool = False, wrapper_class=DataChunk) -> IterDataPipe: ...
    # Functional form of 'CollatorIterDataPipe'
    def collate(self, conversion: Optional[Union[Callable[..., Any],Dict[Union[str, Any], Union[Callable, Any]],]] = default_collate, collate_fn: Optional[Callable] = None) -> IterDataPipe: ...
    # Functional form of 'ConcaterIterDataPipe'
    def concat(self, *datapipes: IterDataPipe) -> IterDataPipe: ...
    # Functional form of 'DemultiplexerIterDataPipe'
    # Declared to return a list of pipes rather than a single pipe.
    def demux(self, num_instances: int, classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000) -> List[IterDataPipe]: ...
    # Functional form of 'FilterIterDataPipe'
    def filter(self, filter_fn: Callable, drop_empty_batches: Optional[bool] = None, input_col=None) -> IterDataPipe: ...
    # Functional form of 'ForkerIterDataPipe'
    # Declared to return a list of pipes rather than a single pipe.
    def fork(self, num_instances: int, buffer_size: int = 1000) -> List[IterDataPipe]: ...
    # Functional form of 'GrouperIterDataPipe'
    def groupby(self, group_key_fn: Callable, *, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False) -> IterDataPipe: ...
    # Functional form of 'FileListerIterDataPipe'
    def list_files(self, masks: Union[str, List[str]] = '', *, recursive: bool = False, abspath: bool = False, non_deterministic: bool = False, length: int = -1) -> IterDataPipe: ...
    # Functional form of 'MapperIterDataPipe'
    def map(self, fn: Callable, input_col=None, output_col=None) -> IterDataPipe: ...
    # Functional form of 'MultiplexerIterDataPipe'
    def mux(self, *datapipes) -> IterDataPipe: ...
    # Functional form of 'FileOpenerIterDataPipe'
    def open_files(self, mode: str = 'r', encoding: Optional[str] = None, length: int = -1) -> IterDataPipe: ...
    # Functional form of 'StreamReaderIterDataPipe'
    def read_from_stream(self, chunk=None) -> IterDataPipe: ...
    # Functional form of 'RoutedDecoderIterDataPipe'
    def routed_decode(self, *handlers: Callable, key_fn: Callable= ...) -> IterDataPipe: ...
    # Functional form of 'ShardingFilterIterDataPipe'
    def sharding_filter(self) -> IterDataPipe: ...
    # Functional form of 'ShufflerIterDataPipe'
    def shuffle(self, *, buffer_size: int = 10000, unbatch_level: int = 0) -> IterDataPipe: ...
    # Functional form of 'UnBatcherIterDataPipe'
    def unbatch(self, unbatch_level: int = 1) -> IterDataPipe: ...
    # Functional form of 'ZipperIterDataPipe'
    def zip(self, *datapipes: IterDataPipe) -> IterDataPipe: ...


class DFIterDataPipe(IterDataPipe):
    """Type stub for an IterDataPipe subclass that declares ``_is_dfpipe``."""
    # NOTE(review): presumably a marker/predicate identifying DataFrame pipes;
    # the return type is not declared here -- confirm against the runtime class.
    def _is_dfpipe(self): ...


class _DataPipeSerializationWrapper:
    """Type stub for a wrapper constructed around a single ``datapipe``.

    Declares custom ``__getstate__`` / ``__setstate__`` (the pickling protocol
    hooks) and ``__len__``; bodies are elided in this stub.
    """
    def __init__(self, datapipe) -> None: ...
    def __getstate__(self): ...
    def __setstate__(self, state): ...
    def __len__(self) -> int: ...


class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    """Type stub combining the serialization wrapper with IterDataPipe; declares ``__iter__``."""
    def __iter__(self): ...


class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
    """Type stub combining the serialization wrapper with MapDataPipe; declares ``__getitem__``."""
    def __getitem__(self, idx): ...


class DataChunk(list, Generic[T]):
    """A ``list`` subclass that also keeps a reference to its source object.

    The list contents are populated from ``items`` at construction time, and
    the very same ``items`` object is retained on ``self.items`` so that
    :meth:`raw_iterator` can iterate it directly.
    """

    def __init__(self, items):
        """Fill the list from ``items`` and remember ``items`` itself."""
        super().__init__(items)
        # Store the caller's object as-is (no copy); raw_iterator() reads it.
        self.items = items

    def as_str(self, indent: str = '') -> str:
        """Return ``indent`` followed by the elements rendered as ``[a, b, ...]``.

        Each element is converted with ``str`` and the elements are joined
        with ``", "``.
        """
        return indent + "[" + ", ".join(str(i) for i in self) + "]"

    def __iter__(self) -> Iterator[T]:
        """Yield the list elements in order (as a generator)."""
        yield from super().__iter__()

    def raw_iterator(self) -> Iterator[T]:
        """Yield the elements of the original ``items`` object.

        This is a generator, so its return type is ``Iterator[T]`` (not ``T``).
        """
        yield from self.items