| | import sys |
| | import copy |
| | from .transforms import * |
| | from .data import * |
| | from .device import * |
| | from .sampling import * |
| | from .neighbors import * |
| | from .point import * |
| | from .graph import * |
| | from .geometry import * |
| | from .partition import * |
| | from .instance import * |
| | from .debug import * |
| | from src.data import Data |
| | import torch_geometric.transforms as pygT |
| | from omegaconf import OmegaConf |
| |
|
| |
|
| | |
| | |
| | |
# Import-time guard: `instantiate_transform` looks transform names up both in
# this module and in `torch_geometric.transforms`, so any public name present
# in both namespaces is rejected here to keep lookups unambiguous.
_spt_tr = sys.modules[__name__]
_pyg_tr = sys.modules["torch_geometric.transforms"]

# Public (non-underscore) names shared by both module namespaces
_intersection_tr = {
    t for t in set(_spt_tr.__dict__) & set(_pyg_tr.__dict__)
    if not t.startswith("_")}

# Shared names whose object does not come from torch_geometric.transforms,
# i.e. local definitions shadowing a PyG transform
_intersection_cls = [
    getattr(_spt_tr, t) for t in _intersection_tr
    if "torch_geometric.transforms." not in str(getattr(_spt_tr, t))]

if _intersection_tr:
    if _intersection_cls:
        raise Exception(
            f"It seems that you are overriding a transform from pytorch "
            f"geometric, this is forbidden, please rename your classes "
            f"{_intersection_tr} from {_intersection_cls}")
    # Only re-exported PyG objects clash: they were star-imported by accident
    raise Exception(
        f"It seems you are importing transforms {_intersection_tr} "
        f"from pytorch geometric within the current code base. Please, "
        f"remove them or add them within a class, function, etc.")
| |
|
| |
|
def _get_option(option, key):
    """Return the value stored under `key` in a transform config, or None.

    Attribute access is tried first (as for OmegaConf DictConfig). Objects
    exposing a callable `get` (e.g. plain dicts, which do not expose their
    keys as attributes) are then queried with `get(key)`.
    """
    value = getattr(option, key, None)
    if value is not None:
        return value
    getter = getattr(option, "get", None)
    if not callable(getter):
        return None
    try:
        return getter(key)
    except KeyError:
        # e.g. missing key in a struct-mode config
        return None


def _to_container(value):
    """Convert an OmegaConf node to native Python containers, resolving
    interpolations. None and non-OmegaConf values are returned unchanged.
    """
    if value is not None and OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    return value


def instantiate_transform(transform_option, attr="transform"):
    """Create a transform from an OmegaConf dict such as:

    ```yaml
    transform: GridSampling3D
    params:
        size: 0.01
    ```

    The class name read under `attr` is looked up in this module first,
    then in `torch_geometric.transforms`. Optional `params` are passed as
    keyword arguments and optional `lparams` as positional arguments.
    Plain Python dicts/lists are accepted as well as OmegaConf nodes.

    :param transform_option: DictConfig or mapping describing the transform
    :param attr: str
        Key holding the name of the transform class
    :return: the instantiated transform
    :raises ValueError: if no string transform name is found under `attr`,
        or if no transform with this name exists
    """
    tr_name = _get_option(transform_option, attr)
    if not isinstance(tr_name, str):
        # Without this check, `getattr(module, None)` would fail with an
        # unhelpful TypeError
        raise ValueError(
            f"Could not read a transform name from the '{attr}' key of "
            f"{transform_option}")

    # Look the name up in this module first, then in torch_geometric
    cls = getattr(_spt_tr, tr_name, None)
    if cls is None:
        cls = getattr(_pyg_tr, tr_name, None)
    if cls is None:
        raise ValueError(f"Transform {tr_name} is nowhere to be found")

    # Keyword arguments from `params`, positional arguments from `lparams`
    tr_params = _to_container(_get_option(transform_option, 'params'))
    lparams = _to_container(_get_option(transform_option, 'lparams'))

    # Missing/empty params simply contribute no arguments
    return cls(*(lparams or ()), **(tr_params or {}))
| |
|
| |
|
def instantiate_transforms(transform_options):
    """Build a torch_geometric `Compose` from an OmegaConf list of
    transform configs such as:

    ```yaml
    - transform: GridSampling3D
      params:
        size: 0.01
    - transform: NormaliseScale
    ```

    Every item is instantiated with `instantiate_transform`. Each pair of
    consecutive transforms is then checked: the `_OUT_TYPE` of the earlier
    one must equal the `_IN_TYPE` of the later one, both defaulting to
    `Data` when the attribute is absent.

    :param transform_options: iterable of transform configs
    :return: `torch_geometric.transforms.Compose` of the transforms
    :raises ValueError: if two consecutive transforms have incompatible
        output/input types
    """
    transforms = [instantiate_transform(option) for option in transform_options]

    # Pair each transform with its successor; no pairs when fewer than two
    for t_prev, t_next in zip(transforms, transforms[1:]):
        produced = getattr(t_prev, '_OUT_TYPE', Data)
        expected = getattr(t_next, '_IN_TYPE', Data)
        if expected != produced:
            raise ValueError(
                f"Cannot compose transforms: {t_prev} returns a {produced} "
                f"while {t_next} expects a {expected} input.")

    return pygT.Compose(transforms)
| |
|
| |
|
def instantiate_datamodule_transforms(transform_options, log=None):
    """Create a dictionary of torch_geometric composite transforms from
    a datamodule OmegaConf holding lists of transforms characterized by
    a `*transform*` key such as:

    ```yaml
    # parsed in the output dictionary
    pre_transform:
      - transform: GridSampling3D
        params:
          size: 0.01
      - transform: NormaliseScale

    # not parsed in the output dictionary
    foo:
      a: 1
      b: 10

    # parsed in the output dictionary
    on_device_transform:
      - transform: NodeSize
      - transform: NAGAddSelfLoops
    ```

    This helper function is typically intended for instantiating the
    transforms of a `BaseDataModule` from an Omegaconf config object.

    :param transform_options: datamodule config. Every string key
        containing 'transform' whose value (read by attribute access) is
        not None is parsed with `instantiate_transforms`
    :param log: optional logger. When None, messages are printed instead
    :return: dict mapping each parsed key, with 'transforms' normalized to
        'transform', to its composed transform. Keys whose transforms fail
        to instantiate are reported and left out

    Credit: https://github.com/torch-points3d/torch-points3d
    """
    transforms_dict = {}
    for key_name in transform_options.keys():
        if not isinstance(key_name, str) or "transform" not in key_name:
            continue

        # Normalize plural keys, e.g. 'pre_transforms' -> 'pre_transform'
        name = key_name.replace("transforms", "transform")
        params = getattr(transform_options, key_name, None)
        if params is None:
            continue

        try:
            transform = instantiate_transforms(params)
        except Exception as e:
            # Best-effort: a broken transform list is reported and skipped
            msg = f"Error trying to create {name}, {params}"
            if log is not None:
                log.exception(msg)
            else:
                print(f"{msg}: {e!r}")
            continue
        transforms_dict[name] = transform

    if not transforms_dict:
        msg = (f"Could not find any '*transform*' key among the provided config"
               f" keys: {transform_options.keys()}. Are you sure you passed a "
               f"datamodule config as input ?")
        # No exception is being handled here, so `log.exception` would emit a
        # meaningless 'NoneType: None' traceback; `error` keeps the same level
        if log is not None:
            log.error(msg)
        else:
            print(msg)
    return transforms_dict
| |
|
| |
|
def explode_transform(transform_list):
    """Return a deep copy of the transforms held by `transform_list` as a
    list.

    :param transform_list: None, a torch_geometric `Compose`, or a list of
        transforms
    :return: list of deep-copied transforms; empty list for None. Only
        the top level is unpacked: nested `Compose` objects are kept as-is
    :raises Exception: for any other input type
    """
    if transform_list is None:
        return []
    if isinstance(transform_list, pygT.Compose):
        return copy.deepcopy(transform_list.transforms)
    if isinstance(transform_list, list):
        return copy.deepcopy(transform_list)
    raise Exception(
        "Transforms should be provided either within a list or "
        "a Compose")
| |
|