File size: 861 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import Union, List
from torch_geometric.transforms import BaseTransform
from src.data import Data
# Public API of this module: only the `Transform` base class is exported.
__all__ = ['Transform']
class Transform(BaseTransform):
    """Transform on `_IN_TYPE` returning `_OUT_TYPE`.

    Subclasses implement `_process`, which handles a single `_IN_TYPE`
    object. Calling the transform on a list applies it element by element
    (recursively, so nested lists are handled too) and returns a list.

    Class attributes:
        _IN_TYPE: type accepted by `__call__` (checked with `isinstance`).
        _OUT_TYPE: type produced by `_process`. Not checked by this class.
        _NO_REPR: names of instance attributes to leave out of `repr()`.
    """

    _IN_TYPE = Data
    _OUT_TYPE = Data
    # Shared class attribute: subclasses should override it with their own
    # list rather than mutating this one in place.
    _NO_REPR = []

    def _process(self, x: _IN_TYPE):
        """Apply the transform to a single `_IN_TYPE` object.

        Must be overridden by subclasses.

        :param x: the object to transform.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement `_process`.")

    def __call__(self, x: Union[_IN_TYPE, List]):
        """Apply the transform to `x`.

        :param x: a single `_IN_TYPE` object, or a (possibly nested) list of
            them; lists are processed element by element.
        :return: the output of `_process` for a single object, or a list of
            outputs (same nesting) for a list input.
        :raises AssertionError: if `x`, or any list element, is neither an
            instance of `_IN_TYPE` nor a list.
        """
        # Explicit raise instead of `assert` so the check is not stripped
        # under `python -O`; AssertionError is kept so callers that already
        # expect it keep working.
        if not isinstance(x, (self._IN_TYPE, list)):
            raise AssertionError(
                f"{self.__class__.__name__} expects {self._IN_TYPE} or a "
                f"list, got {type(x)}.")
        if isinstance(x, list):
            # Recursing through `__call__` re-validates every element.
            return [self.__call__(e) for e in x]
        return self._process(x)

    @property
    def _repr_dict(self) -> dict:
        """Instance attributes shown in `repr()`, excluding `_NO_REPR`."""
        return {
            k: v for k, v in self.__dict__.items()
            if k not in self._NO_REPR}

    def __repr__(self) -> str:
        """Return `ClassName(attr=value, ...)` built from `_repr_dict`.

        Values are formatted with `str()` (not `repr()`).
        """
        attr_repr = ', '.join([f'{k}={v}' for k, v in self._repr_dict.items()])
        return f'{self.__class__.__name__}({attr_repr})'
|