| | import torch |
| | from src.data import Data, NAG, CSRData |
| | from src.transforms import Transform |
| | from src.utils import tensor_idx, to_float_rgb, to_byte_rgb, dropout, \ |
| | sanitize_keys |
| |
|
| |
|
| | __all__ = [ |
| | 'DataToNAG', 'NAGToData', 'Cast', 'NAGCast', 'RemoveKeys', 'NAGRemoveKeys', |
| | 'AddKeysTo', 'NAGAddKeysTo', 'NAGSelectByKey', 'SelectColumns', |
| | 'NAGSelectColumns', 'DropoutColumns', 'NAGDropoutColumns', 'DropoutRows', |
| | 'NAGDropoutRows', 'NAGJitterKey'] |
| |
|
| |
|
| | class DataToNAG(Transform): |
| | """Convert Data to a single-level NAG.""" |
| |
|
| | _IN_TYPE = Data |
| | _OUT_TYPE = NAG |
| |
|
| | def _process(self, data): |
| | return NAG([data]) |
| |
|
| |
|
| | class NAGToData(Transform): |
| | """Convert a single-level NAG to Data.""" |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = Data |
| |
|
| | def _process(self, nag): |
| | assert nag.num_levels == 1 |
| | return nag[0] |
| |
|
| |
|
| | class Cast(Transform): |
| | """Cast Data attributes to the provided integer and floating point |
| | dtypes. In case 'rgb' or 'mean_rgb' is found, `rgb_to_float` will |
| | decide whether it should be cast to 'fp_dtype' or 'uint8'. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | fp_dtype=torch.float, |
| | int_dtype=torch.long, |
| | rgb_to_float=True): |
| | self.fp_dtype = fp_dtype |
| | self.int_dtype = int_dtype |
| | self.rgb_to_float = rgb_to_float |
| |
|
| | def _process(self, data): |
| | for k in data.keys: |
| |
|
| | |
| | |
| | if isinstance(data[k], CSRData): |
| | values = [] |
| | for v in data[k].values: |
| | values.append(self._process(Data(foo=v)).foo) |
| | data[k].values = values |
| | data[k].pointers = data[k].pointers.long() |
| | continue |
| |
|
| | |
| | if k in ['rgb', 'mean_rgb']: |
| | data[k] = to_float_rgb(data[k]).to(self.fp_dtype)\ |
| | if self.rgb_to_float else to_byte_rgb(data[k]) |
| | continue |
| |
|
| | |
| | if isinstance(data[k], torch.Tensor): |
| | data[k] = data[k].to(self.fp_dtype) \ |
| | if data[k].is_floating_point() \ |
| | else data[k].to(self.int_dtype) |
| | continue |
| |
|
| | |
| |
|
| | return data |
| |
|
| |
|
| | class NAGCast(Cast): |
| | """Cast NAG attributes to the provided integer and floating point |
| | dtypes. In case 'rgb' or 'mean_rgb' is found and is not a floating |
| | point tensor, `rgb_to_float` will decide whether it should be cast |
| | to floats. |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| |
|
| | def _process(self, nag): |
| | transform = Cast( |
| | fp_dtype=self.fp_dtype, |
| | int_dtype=self.int_dtype, |
| | rgb_to_float=self.rgb_to_float) |
| |
|
| | for i_level in range(nag.num_levels): |
| | nag._list[i_level] = transform(nag[i_level]) |
| |
|
| | return nag |
| |
|
| |
|
| | class RemoveKeys(Transform): |
| | """Remove attributes of a Data object based on their name. |
| | |
| | :param keys: str of list(str) |
| | List of attribute names |
| | :param strict: bool |
| | If True, will raise an exception if an attribute from key is |
| | not within the input Data keys |
| | """ |
| |
|
| | _NO_REPR = ['strict'] |
| |
|
| | def __init__(self, keys=None, strict=False): |
| | self.keys = sanitize_keys(keys, default=[]) |
| | self.strict = strict |
| |
|
| | def _process(self, data): |
| | keys = set(data.keys) |
| | for k in self.keys: |
| | if k not in keys and self.strict: |
| | raise Exception(f"key: {k} is not within Data keys: {keys}") |
| | for k in self.keys: |
| | delattr(data, k) |
| | return data |
| |
|
| |
|
| | class NAGRemoveKeys(Transform): |
| | """Remove attributes of a NAG object based on their name. |
| | |
| | :param level: int or str |
| | Level at which to remove attributes. Can be an int or a str. If |
| | the latter, 'all' will apply on all levels, 'i+' will apply on |
| | level-i and above, 'i-' will apply on level-i and below |
| | :param keys: str or list(str) |
| | List of attribute names |
| | :param strict: bool=False |
| | If True, will raise an exception if an attribute from key is |
| | not within the input Data keys |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| | _NO_REPR = ['strict'] |
| |
|
| | def __init__(self, level='all', keys=None, strict=False): |
| | assert isinstance(level, (int, str)) |
| | self.level = level |
| | self.keys = sanitize_keys(keys, default=[]) |
| | self.strict = strict |
| |
|
| | def _process(self, nag): |
| |
|
| | level_keys = [[]] * nag.num_levels |
| | if isinstance(self.level, int): |
| | level_keys[self.level] = self.keys |
| | elif self.level == 'all': |
| | level_keys = [self.keys] * nag.num_levels |
| | elif self.level[-1] == '+': |
| | i = int(self.level[:-1]) |
| | level_keys[i:] = [self.keys] * (nag.num_levels - i) |
| | elif self.level[-1] == '-': |
| | i = int(self.level[:-1]) |
| | level_keys[:i] = [self.keys] * i |
| | else: |
| | raise ValueError(f'Unsupported level={self.level}') |
| |
|
| | transforms = [ |
| | RemoveKeys(keys=k, strict=self.strict) for k in level_keys] |
| |
|
| | for i_level in range(nag.num_levels): |
| | nag._list[i_level] = transforms[i_level](nag._list[i_level]) |
| |
|
| | return nag |
| |
|
| |
|
| | class AddKeysTo(Transform): |
| | """Get attributes from their keys and concatenate them to x. |
| | |
| | :param keys: str or list(str) |
| | The feature concatenated to 'to' |
| | :param to: str |
| | Destination attribute where the features in 'keys' will be |
| | concatenated |
| | :param strict: bool |
| | Whether we want to raise an error if a key is not found |
| | :param delete_after: bool |
| | Whether the Data attributes should be removed once added to 'to' |
| | """ |
| |
|
| | _NO_REPR = ['strict'] |
| |
|
| | def __init__(self, keys=None, to='x', strict=True, delete_after=True): |
| | self.keys = [keys] if isinstance(keys, str) else keys |
| | self.to = to |
| | self.strict = strict |
| | self.delete_after = delete_after |
| |
|
| | def _process_single_key(self, data, key, to): |
| | |
| | feat = getattr(data, key, None) |
| | x = getattr(data, to, None) |
| |
|
| | |
| | if feat is None: |
| | if self.strict: |
| | raise Exception(f"Data should contain the attribute '{key}'") |
| | else: |
| | return data |
| |
|
| | |
| | if self.delete_after: |
| | delattr(data, key) |
| |
|
| | |
| | if x is None: |
| | if self.strict and data.num_nodes != feat.shape[0]: |
| | raise Exception(f"Data should contain the attribute '{to}'") |
| | if feat.dim() == 1: |
| | feat = feat.unsqueeze(-1) |
| | data[to] = feat |
| | return data |
| |
|
| | |
| | if x.shape[0] != feat.shape[0]: |
| | raise Exception( |
| | f"The tensors '{to}' and '{key}' can't be concatenated, " |
| | f"'{to}': {x.shape[0]}, '{key}': {feat.shape[0]}") |
| |
|
| | |
| | if x.dim() == 1: |
| | x = x.unsqueeze(-1) |
| | if feat.dim() == 1: |
| | feat = feat.unsqueeze(-1) |
| | data[to] = torch.cat([x, feat], dim=-1) |
| |
|
| | return data |
| |
|
| | def _process(self, data): |
| | if self.keys is None or len(self.keys) == 0: |
| | return data |
| |
|
| | for key in self.keys: |
| | data = self._process_single_key(data, key, self.to) |
| |
|
| | return data |
| |
|
| |
|
| | class NAGAddKeysTo(Transform): |
| | """Get attributes from their keys and concatenate them to x. |
| | |
| | :param level: int or str |
| | Level at which to remove attributes. Can be an int or a str. If |
| | the latter, 'all' will apply on all levels, 'i+' will apply on |
| | level-i and above, 'i-' will apply on level-i and below |
| | :param keys: str or list(str) |
| | The feature concatenated to 'to' |
| | :param to: str |
| | Destination attribute where the features in 'keys' will be |
| | concatenated |
| | :param strict: bool |
| | Whether we want to raise an error if a key is not found |
| | :param delete_after: bool |
| | Whether the Data attributes should be removed once added to 'to' |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| | _NO_REPR = ['strict'] |
| |
|
| | def __init__( |
| | self, level='all', keys=None, to='x', strict=True, |
| | delete_after=True): |
| | self.level = level |
| | self.keys = [keys] if isinstance(keys, str) else keys |
| | self.to = to |
| | self.strict = strict |
| | self.delete_after = delete_after |
| |
|
| | def _process(self, nag): |
| |
|
| | level_keys = [[]] * nag.num_levels |
| | if isinstance(self.level, int): |
| | level_keys[self.level] = self.keys |
| | elif self.level == 'all': |
| | level_keys = [self.keys] * nag.num_levels |
| | elif self.level[-1] == '+': |
| | i = int(self.level[:-1]) |
| | level_keys[i:] = [self.keys] * (nag.num_levels - i) |
| | elif self.level[-1] == '-': |
| | i = int(self.level[:-1]) |
| | level_keys[:i] = [self.keys] * i |
| | else: |
| | raise ValueError(f'Unsupported level={self.level}') |
| |
|
| | transforms = [ |
| | AddKeysTo( |
| | keys=k, to=self.to, strict=self.strict, |
| | delete_after=self.delete_after) |
| | for k in level_keys] |
| |
|
| | for i_level in range(nag.num_levels): |
| | nag._list[i_level] = transforms[i_level](nag._list[i_level]) |
| |
|
| | return nag |
| |
|
| |
|
| | class NAGSelectByKey(Transform): |
| | """Select the i-level nodes based on a key. The corresponding key is |
| | expected to exist in the i-level attributes and should hold a 1D |
| | boolean mask. |
| | |
| | :param key: str |
| | Key attribute expected to be found in the input NAG's `level`. |
| | The `key` attribute should carry a 1D boolean mask over the |
| | `level` nodes |
| | :param level: int |
| | NAG level based on which to operate the selection |
| | :param negation: bool |
| | Whether the mask or its complementary should be used |
| | :param strict: bool |
| | Whether we want to raise an error if the key is not found or if |
| | it does not carry a 1D boolean mask |
| | :param delete_after: bool |
| | Whether the `key` attribute should be removed after selection |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| | _NO_REPR = ['strict'] |
| |
|
| | def __init__( |
| | self, key=None, level=0, negation=False, strict=True, |
| | delete_after=True): |
| | assert key is not None |
| | self.key = key |
| | self.level = level |
| | self.negation = negation |
| | self.strict = strict |
| | self.delete_after = delete_after |
| |
|
| | def _process(self, nag): |
| | |
| | if self.key not in nag[self.level].keys: |
| | if self.strict: |
| | raise ValueError( |
| | f'Input NAG does not have `{self.key}` attribute at ' |
| | f'level `{self.level}`') |
| | return nag |
| |
|
| | |
| | mask = nag[self.level][self.key] |
| |
|
| | |
| | dtype = mask.dtype |
| | if dtype != torch.bool: |
| | if self.strict: |
| | raise ValueError( |
| | f'`{self.key}` attribute has dtype={dtype} but ' |
| | f'dtype=torch.bool was expected') |
| | return nag |
| |
|
| | |
| | expected_size = torch.Size((nag[self.level].num_nodes,)) |
| | actual_size = mask.shape |
| | if expected_size != actual_size: |
| | if self.strict: |
| | raise ValueError( |
| | f'`{self.key}` attribute has shape={actual_size} but ' |
| | f'shape={expected_size} was expected') |
| | return nag |
| |
|
| | |
| | mask = ~mask if self.negation else mask |
| | nag = nag.select(self.level, torch.where(mask)[0]) |
| |
|
| | |
| | if self.delete_after: |
| | nag[self.level][self.key] = None |
| |
|
| | return nag |
| |
|
| |
|
| | class SelectColumns(Transform): |
| | """Select columns of an attribute based on their indices. |
| | |
| | :param key: str |
| | The Data attribute whose columns should be selected |
| | :param idx: int, Tensor or list |
| | The indices of the edge features to keep. If None, this |
| | transform will have no effect and edge features will be left |
| | untouched |
| | """ |
| |
|
| | def __init__(self, key=None, idx=None): |
| | assert key is not None, f"A Data key must be specified" |
| | self.key = key |
| | self.idx = tensor_idx(idx) if idx is not None else None |
| |
|
| | def _process(self, data): |
| | if self.idx is None or getattr(data, self.key, None) is None: |
| | return data |
| | idx = tensor_idx(torch.as_tensor(self.idx, device=data.device)) |
| | data[self.key] = data[self.key][:, idx] |
| | return data |
| |
|
| |
|
| | class NAGSelectColumns(Transform): |
| | """Select columns of an attribute based on their indices. |
| | |
| | :param level: int or str |
| | Level at which to select attributes. Can be an int or a str. If |
| | the latter, 'all' will apply on all levels, 'i+' will apply on |
| | level-i and above, 'i-' will apply on level-i and below |
| | :param key: str |
| | The Data attribute whose columns should be selected |
| | :param idx: int, Tensor or list |
| | The indices of the edge features to keep. If None, this |
| | transform will have no effect and edge features will be left |
| | untouched |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| |
|
| | def __init__(self, level='all', key=None, idx=None): |
| | self.level = level |
| | self.key = key |
| | self.idx = idx |
| |
|
| | def _process(self, nag): |
| |
|
| | level_idx = [None] * nag.num_levels |
| | if isinstance(self.level, int): |
| | level_idx[self.level] = self.idx |
| | elif self.level == 'all': |
| | level_idx = [self.idx] * nag.num_levels |
| | elif self.level[-1] == '+': |
| | i = int(self.level[:-1]) |
| | level_idx[i:] = [self.idx] * (nag.num_levels - i) |
| | elif self.level[-1] == '-': |
| | i = int(self.level[:-1]) |
| | level_idx[:i] = [self.idx] * i |
| | else: |
| | raise ValueError(f'Unsupported level={self.level}') |
| |
|
| | transforms = [SelectColumns(key=self.key, idx=idx) for idx in level_idx] |
| |
|
| | for i_level in range(nag.num_levels): |
| | nag._list[i_level] = transforms[i_level](nag._list[i_level]) |
| |
|
| | return nag |
| |
|
| |
|
| | class DropoutColumns(Transform): |
| | """Randomly set a Data attribute column to 0. |
| | |
| | :param p: float |
| | Probability of a column to be dropped |
| | :param key: str |
| | The Data attribute whose columns should be selected |
| | :param inplace: bool |
| | Whether the dropout should be performed directly on the input |
| | or on a copy of it |
| | :param to_mean: bool |
| | Whether the dropped values should be set to the mean of their |
| | corresponding column (dim=1) or to zero (default) |
| | """ |
| |
|
| | def __init__(self, p=0.5, key=None, inplace=False, to_mean=False): |
| | assert key is not None, f"A Data key must be specified" |
| | self.p = p |
| | self.key = key |
| | self.inplace = inplace |
| | self.to_mean = to_mean |
| |
|
| | def _process(self, data): |
| | |
| | if self.p <= 0: |
| | return data |
| |
|
| | |
| | if getattr(data, self.key, None) is None: |
| | return data |
| |
|
| | |
| | data[self.key] = dropout( |
| | data[self.key], p=self.p, dim=1, inplace=self.inplace, |
| | to_mean=self.to_mean) |
| |
|
| | return data |
| |
|
| |
|
| | class NAGDropoutColumns(Transform): |
| | """Randomly set a Data attribute column to 0. |
| | |
| | :param level: int or str |
| | Level at which to drop columns. Can be an int or a str. If |
| | the latter, 'all' will apply on all levels, 'i+' will apply on |
| | level-i and above, 'i-' will apply on level-i and below |
| | :param p: float |
| | Probability of a column to be dropped |
| | :param key: str |
| | The Data attribute whose columns should be selected |
| | :param inplace: bool |
| | Whether the dropout should be performed directly on the input |
| | or on a copy of it |
| | :param to_mean: bool |
| | Whether the dropped values should be set to the mean of their |
| | corresponding column (dim=1) or to zero (default) |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| |
|
| | def __init__( |
| | self, level='all', p=0.5, key=None, inplace=False, to_mean=False): |
| | assert isinstance(level, int) or level == 'all' or level.endswith('-') \ |
| | or level.endswith('+') |
| | self.level = level |
| | self.p = p |
| | self.key = key |
| | self.inplace = inplace |
| | self.to_mean = to_mean |
| |
|
| | def _process(self, nag): |
| | |
| | if self.p <= 0: |
| | return nag |
| |
|
| | if isinstance(self.level, int): |
| | levels = [self.level] |
| | elif self.level == 'all': |
| | levels = range(0, nag.num_levels) |
| | elif self.level[-1] == '+': |
| | levels = range(int(self.level[:-1]), nag.num_levels) |
| | elif self.level[-1] == '-': |
| | levels = range(0, int(self.level[:-1]) + 1) |
| | else: |
| | return nag |
| |
|
| | for i_level in levels: |
| | |
| | if getattr(nag[i_level], self.key, None) is None: |
| | continue |
| |
|
| | |
| | nag[i_level][self.key] = dropout( |
| | nag[i_level][self.key], p=self.p, dim=1, inplace=self.inplace, |
| | to_mean=self.to_mean) |
| |
|
| | return nag |
| |
|
| |
|
| | class DropoutRows(Transform): |
| | """Randomly set a Data attribute rows to 0. |
| | |
| | :param p: float |
| | Probability of a row to be dropped |
| | :param key: str |
| | The Data attribute whose rows should be selected |
| | :param inplace: bool |
| | Whether the dropout should be performed directly on the input |
| | or on a copy of it |
| | :param to_mean: bool |
| | Whether the dropped values should be set to the mean of their |
| | corresponding column (dim=1) or to zero (default) |
| | """ |
| |
|
| | def __init__(self, p=0.5, key=None, inplace=False, to_mean=False): |
| | assert key is not None, f"A Data key must be specified" |
| | self.p = p |
| | self.key = key |
| | self.inplace = inplace |
| | self.to_mean = to_mean |
| |
|
| |
|
| | def _process(self, data): |
| | |
| | if self.p <= 0: |
| | return data |
| |
|
| | |
| | if getattr(data, self.key, None) is None: |
| | return data |
| |
|
| | |
| | data[self.key] = dropout( |
| | data[self.key], p=self.p, dim=0, inplace=self.inplace, |
| | to_mean=self.to_mean) |
| |
|
| | return data |
| |
|
| |
|
| | class NAGDropoutRows(Transform): |
| | """Randomly set a Data attribute rows to 0. |
| | |
| | :param level: int or str |
| | Level at which to drop rows. Can be an int or a str. If |
| | the latter, 'all' will apply on all levels, 'i+' will apply on |
| | level-i and above, 'i-' will apply on level-i and below |
| | :param p: float |
| | Probability of a row to be dropped |
| | :param key: str |
| | The Data attribute whose rows should be selected |
| | :param inplace: bool |
| | Whether the dropout should be performed directly on the input |
| | or on a copy of it |
| | :param to_mean: bool |
| | Whether the dropped values should be set to the mean of their |
| | corresponding column (dim=1) or to zero (default) |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| |
|
| | def __init__( |
| | self, level='all', p=0.5, key=None, inplace=False, to_mean=False): |
| | assert isinstance(level, int) or level == 'all' or level.endswith('-') \ |
| | or level.endswith('+') |
| | self.level = level |
| | self.p = p |
| | self.key = key |
| | self.inplace = inplace |
| | self.to_mean = to_mean |
| |
|
| | def _process(self, nag): |
| | |
| | if self.p <= 0: |
| | return nag |
| |
|
| | if isinstance(self.level, int): |
| | levels = [self.level] |
| | elif self.level == 'all': |
| | levels = range(0, nag.num_levels) |
| | elif self.level[-1] == '+': |
| | levels = range(int(self.level[:-1]), nag.num_levels) |
| | elif self.level[-1] == '-': |
| | levels = range(0, int(self.level[:-1]) + 1) |
| | else: |
| | return nag |
| |
|
| | for i_level in levels: |
| | |
| | if getattr(nag[i_level], self.key, None) is None: |
| | continue |
| |
|
| | |
| | nag[i_level][self.key] = dropout( |
| | nag[i_level][self.key], p=self.p, dim=0, inplace=self.inplace, |
| | to_mean=self.to_mean) |
| |
|
| | return nag |
| |
|
| |
|
| | class NAGJitterKey(Transform): |
| | """Add some gaussian noise to Data['key'] for all data in a NAG. |
| | |
| | :param key: str |
| | The attribute on which to apply jittering |
| | :param sigma: float or List(float) |
| | Standard deviation of the gaussian noise. A list may be passed |
| | to transform NAG levels with different parameters. Passing |
| | sigma <= 0 will prevent any jittering |
| | :param trunc: float or List(float) |
| | Standard deviation of the gaussian noise. A list may be passed |
| | to transform NAG levels with different parameters. Passing |
| | trunc <= 0 will not truncate the normal distribution |
| | :param strict: bool |
| | Whether an error should be raised if one of the input NAG levels |
| | does not have `key` attribute |
| | """ |
| |
|
| | _IN_TYPE = NAG |
| | _OUT_TYPE = NAG |
| |
|
| | def __init__(self, key=None, sigma=0.01, trunc=0.05, strict=False): |
| | assert key is not None, "A key must be specified" |
| | assert isinstance(sigma, (int, float, list)) |
| | assert isinstance(trunc, (int, float, list)) |
| | self.key = key |
| | self.sigma = sigma |
| | self.trunc = trunc |
| | self.strict = strict |
| |
|
| | def _process(self, nag): |
| | if not isinstance(self.sigma, list): |
| | sigma = [self.sigma] * nag.num_levels |
| | else: |
| | sigma = self.sigma |
| |
|
| | if not isinstance(self.trunc, list): |
| | trunc = [self.trunc] * nag.num_levels |
| | else: |
| | trunc = self.trunc |
| |
|
| | for i_level in range(nag.num_levels): |
| |
|
| | if sigma[i_level] <= 0: |
| | continue |
| |
|
| | if getattr(nag[i_level], self.key, None) is None: |
| | if self.strict: |
| | raise ValueError( |
| | f"Input data does not have any '{self.key} attribute") |
| | else: |
| | continue |
| |
|
| | if trunc[i_level] > 0: |
| | noise = torch.nn.init.trunc_normal_( |
| | torch.empty_like(nag[i_level][self.key]), |
| | mean=0., |
| | std=sigma[i_level], |
| | a=-trunc[i_level], |
| | b=trunc[i_level]) |
| | else: |
| | noise = torch.randn_like( |
| | nag[i_level][self.key]) * sigma[i_level] |
| |
|
| | nag[i_level][self.key] += noise |
| |
|
| | return nag |
| |
|