|
|
import sys |
|
|
import copy |
|
|
from .transforms import * |
|
|
from .data import * |
|
|
from .device import * |
|
|
from .sampling import * |
|
|
from .neighbors import * |
|
|
from .point import * |
|
|
from .graph import * |
|
|
from .geometry import * |
|
|
from .partition import * |
|
|
from .instance import * |
|
|
from .debug import * |
|
|
from src.data import Data |
|
|
import torch_geometric.transforms as pygT |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Handles on this package and on PyG's transforms package. They serve both
# the name-collision check below and the name lookup performed by
# `instantiate_transform`
_spt_tr = sys.modules[__name__]
_pyg_tr = sys.modules["torch_geometric.transforms"]

# Public (non-underscore) names present in both namespaces
_intersection_tr = {
    t for t in set(_spt_tr.__dict__) & set(_pyg_tr.__dict__)
    if not t.startswith("_")}

# Among the shared names, the objects whose string representation does not
# mention PyG's transforms package, i.e. local objects shadowing a PyG name
_intersection_cls = [
    obj for obj in (getattr(_spt_tr, t) for t in _intersection_tr)
    if "torch_geometric.transforms." not in str(obj)]

# Any shared public name is rejected: either a local definition overrides a
# PyG transform, or a PyG transform was imported into this namespace
if _intersection_cls:
    raise Exception(
        f"It seems that you are overriding a transform from pytorch "
        f"geometric, this is forbidden, please rename your classes "
        f"{_intersection_tr} from {_intersection_cls}")
if _intersection_tr:
    raise Exception(
        f"It seems you are importing transforms {_intersection_tr} "
        f"from pytorch geometric within the current code base. Please, "
        f"remove them or add them within a class, function, etc.")
|
|
|
|
|
|
|
|
def _option_to_container(transform_option, key):
    """Return `transform_option.get(key)` converted to a plain Python
    container (with interpolations resolved), or None if the value is
    missing/None or if a `KeyError` is raised while reading it.
    """
    try:
        value = transform_option.get(key)
        if value is not None:
            value = OmegaConf.to_container(value, resolve=True)
    except KeyError:
        # NOTE(review): presumably guards struct-mode configs raising on
        # missing keys — confirm against the OmegaConf version in use
        value = None
    return value


def instantiate_transform(transform_option, attr="transform"):
    """Create a transform from an OmegaConf dict such as:

    ```yaml
    transform: GridSampling3D
    params:
      size: 0.01
    ```

    The transform class is looked up by name first in this package's
    namespace, then in `torch_geometric.transforms`. Keyword arguments are
    read from the optional `params` entry and positional arguments from
    the optional `lparams` entry.

    :param transform_option: config object holding the transform name
        under `attr` and, optionally, `params` / `lparams`
    :param attr: name of the attribute holding the transform name
    :return: the instantiated transform
    :raises ValueError: if the transform name is missing, is not a string,
        or matches no known transform
    """
    tr_name = getattr(transform_option, attr, None)

    # Without this guard a missing name would surface as an opaque
    # TypeError from `getattr(module, None)` below
    if not isinstance(tr_name, str):
        raise ValueError(
            f"Expected a transform name under '{attr}', got {tr_name!r}")

    # Local transforms take precedence over PyG ones
    cls = getattr(_spt_tr, tr_name, None)
    if cls is None:
        cls = getattr(_pyg_tr, tr_name, None)
    if cls is None:
        raise ValueError(f"Transform {tr_name} is nowhere to be found")

    tr_params = _option_to_container(transform_option, 'params')
    lparams = _option_to_container(transform_option, 'lparams')

    # Empty or missing `lparams` / `params` simply contribute no arguments
    return cls(*(lparams or ()), **(tr_params or {}))
|
|
|
|
|
|
|
|
def instantiate_transforms(transform_options):
    """Build a torch_geometric `Compose` transform from an OmegaConf list
    of transform configs such as:

    ```yaml
    - transform: GridSampling3D
      params:
        size: 0.01
    - transform: NormaliseScale
    ```

    Each entry is instantiated with `instantiate_transform`. Consecutive
    transforms are then checked for compatibility: a transform's
    `_IN_TYPE` must equal its predecessor's `_OUT_TYPE`, with both
    defaulting to `Data` when the attribute is absent.

    :param transform_options: iterable of transform configs
    :return: a `torch_geometric.transforms.Compose` of the transforms
    :raises ValueError: if two consecutive transforms have mismatching
        output/input types
    """
    transforms = [instantiate_transform(option) for option in transform_options]

    # Walk (producer, consumer) neighbours; nothing to check with fewer than
    # two transforms
    for producer, consumer in zip(transforms, transforms[1:]):
        out_type = getattr(producer, '_OUT_TYPE', Data)
        in_type = getattr(consumer, '_IN_TYPE', Data)
        if in_type != out_type:
            raise ValueError(
                f"Cannot compose transforms: {producer} returns a {out_type} "
                f"while {consumer} expects a {in_type} input.")

    return pygT.Compose(transforms)
|
|
|
|
|
|
|
|
def instantiate_datamodule_transforms(transform_options, log=None):
    """Create a dictionary of torch_geometric composite transforms from
    a datamodule OmegaConf holding lists of transforms characterized by
    a `*transform*` key such as:

    ```yaml
    # parsed in the output dictionary
    pre_transform:
      - transform: GridSampling3D
        params:
          size: 0.01
      - transform: NormaliseScale

    # not parsed in the output dictionary
    foo:
      a: 1
      b: 10

    # parsed in the output dictionary
    on_device_transform:
      - transform: NodeSize
      - transform: NAGAddSelfLoops
    ```

    This helper function is typically intended for instantiating the
    transforms of a `BaseDataModule` from an Omegaconf config object.

    Keys whose value is None are skipped. Entries that fail to instantiate
    are reported and skipped, so the remaining transforms can still be
    built.

    Credit: https://github.com/torch-points3d/torch-points3d

    :param transform_options: config whose `*transform*` keys hold lists
        of transform configs
    :param log: optional logger exposing `exception` and `error`; when
        None, messages are printed instead
    :return: dict mapping key names (with any "transforms" substring
        rewritten as "transform") to `Compose` transforms
    """
    import traceback

    transforms_dict = {}
    for key_name in transform_options.keys():
        if "transform" not in key_name:
            continue

        # Normalize plural keys, e.g. `train_transforms` -> `train_transform`
        name = key_name.replace("transforms", "transform")
        params = getattr(transform_options, key_name, None)
        if params is None:
            continue

        try:
            transform = instantiate_transforms(params)
        except Exception:
            # Deliberate best-effort: report the failure, skip this entry
            msg = f"Error trying to create {name}, {params}"
            if log is not None:
                log.exception(msg)
            else:
                print(msg)
                # Keep the failure diagnosable when no logger is given
                traceback.print_exc()
            continue

        transforms_dict[name] = transform

    if len(transforms_dict) == 0:
        msg = (f"Could not find any '*transform*' key among the provided config"
               f" keys: {transform_options.keys()}. Are you sure you passed a "
               f"datamodule config as input ?")
        # Not inside an `except` clause: `log.exception` would attach a
        # meaningless "NoneType: None" traceback, so log at error level
        if log is not None:
            log.error(msg)
        else:
            print(msg)

    return transforms_dict
|
|
|
|
|
|
|
|
def explode_transform(transform_list):
    """Extract a list of transforms from a `Compose` or from a list (or
    tuple) of transforms.

    Only one level is unwrapped: a `Compose` nested inside the input is
    kept as a single element. The returned transforms are deep copies, so
    modifying them does not affect the input.

    :param transform_list: a `torch_geometric.transforms.Compose`, a list
        or tuple of transforms, or None
    :return: a new list of deep-copied transforms; empty if
        `transform_list` is None
    :raises TypeError: if `transform_list` is of any other type
    """
    if transform_list is None:
        return []
    if isinstance(transform_list, (list, tuple)):
        return list(copy.deepcopy(transform_list))
    if isinstance(transform_list, pygT.Compose):
        return copy.deepcopy(transform_list.transforms)
    raise TypeError(
        "Transforms should be provided either within a list, a tuple or "
        f"a Compose, got {type(transform_list).__name__}")
|
|
|