English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
import sys
import copy
from .transforms import *
from .data import *
from .device import *
from .sampling import *
from .neighbors import *
from .point import *
from .graph import *
from .geometry import *
from .partition import *
from .instance import *
from .debug import *
from src.data import Data
import torch_geometric.transforms as pygT
from omegaconf import OmegaConf
# Fuse all transforms defined in this project with the torch_geometric
# transforms, so that `instantiate_transform` can look a transform up by
# name in either module. Special attention is given to local transforms
# that may have the same name as some torch_geometric transform: any
# public name shared by both module namespaces aborts the import.
_spt_tr = sys.modules[__name__]
_pyg_tr = sys.modules["torch_geometric.transforms"]

# Public (non-underscore) names present in both module namespaces
_intersection_tr = {
    t for t in set(_spt_tr.__dict__) & set(_pyg_tr.__dict__)
    if not t.startswith("_")}

# Shared objects that do not come from torch_geometric.transforms, i.e.
# genuine local overrides. Provenance is detected from the object's
# string representation containing the torch_geometric.transforms path.
# NOTE: comprehensions are used rather than a module-level `for` loop so
# that no loop variable leaks into (and gets star-exported from) this
# module's public namespace
_intersection_cls = [
    obj for obj in (getattr(_spt_tr, t) for t in _intersection_tr)
    if "torch_geometric.transforms." not in str(obj)]

if _intersection_tr:
    if _intersection_cls:
        # A local definition shadows a torch_geometric name
        raise Exception(
            "It seems that you are overriding a transform from pytorch "
            "geometric, this is forbidden, please rename your classes "
            f"{_intersection_tr} from {_intersection_cls}")
    # Every shared name is a torch_geometric object re-imported here
    raise Exception(
        f"It seems you are importing transforms {_intersection_tr} "
        "from pytorch geometric within the current code base. Please, "
        "remove them or add them within a class, function, etc.")
def instantiate_transform(transform_option, attr="transform"):
    """Create a transform from an OmegaConf dict such as:

    ```yaml
    transform: GridSampling3D
    params:
        size: 0.01
    ```

    The class named under `attr` is looked up first among the transforms
    exposed by this module, then among `torch_geometric.transforms`. An
    optional `params` entry is passed as keyword arguments and an
    optional `lparams` entry as positional arguments.

    :param transform_option: OmegaConf config describing the transform
    :param attr: str
        Name of the config field holding the transform class name
    :return: the instantiated transform
    :raises ValueError: if no transform name is found under `attr`, or
        if no transform with that name can be found
    """
    # Read the transform name. A missing name would otherwise surface as
    # a cryptic `TypeError` from the `getattr` lookups below
    tr_name = getattr(transform_option, attr, None)
    if not isinstance(tr_name, str):
        raise ValueError(
            f"Expected a transform name under '{attr}', got {tr_name!r}")

    # Find the transform class corresponding to the name, local project
    # transforms taking precedence over torch_geometric ones
    cls = getattr(_spt_tr, tr_name, None)
    if cls is None:
        cls = getattr(_pyg_tr, tr_name, None)
    if cls is None:
        raise ValueError(f"Transform {tr_name} is nowhere to be found")

    def read_container(key):
        # Fetch `key` from the config and convert it to a plain, resolved
        # Python container. A `KeyError` raised while reading is treated
        # like a missing key
        try:
            value = transform_option.get(key)  # Update to OmegaConf 2.0
            if value is not None:
                value = OmegaConf.to_container(value, resolve=True)
        except KeyError:
            value = None
        return value

    # Parse the keyword (`params`) and positional (`lparams`) arguments
    tr_params = read_container('params')
    lparams = read_container('lparams')

    # Instantiate the transform. Missing or empty `params` / `lparams`
    # simply contribute no arguments
    return cls(*(lparams or ()), **(tr_params or {}))
def instantiate_transforms(transform_options):
    """Build a torch_geometric `Compose` from an OmegaConf list of
    transform configs such as:

    ```yaml
    - transform: GridSampling3D
      params:
        size: 0.01
    - transform: NormaliseScale
    ```

    Each item is instantiated with `instantiate_transform`. When several
    transforms are chained, every consecutive pair is checked: the
    `_OUT_TYPE` of a transform must equal the `_IN_TYPE` of the next one,
    both attributes defaulting to `Data` when absent.

    :param transform_options: iterable of transform configs
    :return: `pygT.Compose` wrapping the instantiated transforms
    :raises ValueError: if two consecutive transforms have mismatching
        output/input types
    """
    transforms = [instantiate_transform(opt) for opt in transform_options]

    # Walk the (producer, consumer) pairs; with fewer than two transforms
    # there is no pair and hence nothing to check
    for producer, consumer in zip(transforms, transforms[1:]):
        produced = getattr(producer, '_OUT_TYPE', Data)
        expected = getattr(consumer, '_IN_TYPE', Data)
        if expected != produced:
            raise ValueError(
                f"Cannot compose transforms: {producer} returns a {produced} "
                f"while {consumer} expects a {expected} input.")

    return pygT.Compose(transforms)
def instantiate_datamodule_transforms(transform_options, log=None):
    """Create a dictionary of torch_geometric composite transforms from
    a datamodule OmegaConf holding lists of transforms characterized by
    a `*transform*` key such as:

    ```yaml
    # parsed in the output dictionary
    pre_transform:
        - transform: GridSampling3D
          params:
            size: 0.01
        - transform: NormaliseScale

    # not parsed in the output dictionary
    foo:
        a: 1
        b: 10

    # parsed in the output dictionary
    on_device_transform:
        - transform: NodeSize
        - transform: NAGAddSelfLoops
    ```

    Keys containing "transforms" are stored under their singular form
    (e.g. `pre_transforms` -> `pre_transform`). Keys whose value is None
    are skipped. A transform list that fails to instantiate is reported
    (with traceback) and skipped rather than aborting the whole call.

    This helper function is typically intended for instantiating the
    transforms of a `BaseDataModule` from an Omegaconf config object

    Credit: https://github.com/torch-points3d/torch-points3d

    :param transform_options: datamodule config object
    :param log: optional logger used to report problems; when None,
        messages are printed instead
    :return: dict mapping each parsed key to its `Compose` transform
    """
    def report(msg, with_traceback=False):
        # Send `msg` to the logger when provided, else print it.
        # `Logger.exception` is only meaningful inside an exception
        # handler; elsewhere it would log a bogus "NoneType: None"
        # traceback, so plain errors go through `Logger.error`
        if log is None:
            print(msg)
        elif with_traceback:
            log.exception(msg)
        else:
            log.error(msg)

    transforms_dict = {}
    for key_name in transform_options.keys():
        if "transform" not in key_name:
            continue
        name = key_name.replace("transforms", "transform")
        params = getattr(transform_options, key_name, None)
        if params is None:
            continue
        try:
            transform = instantiate_transforms(params)
        except Exception:
            # Deliberate best-effort: report and move on to the next key
            report(f"Error trying to create {name}, {params}",
                   with_traceback=True)
            continue
        transforms_dict[name] = transform

    if not transforms_dict:
        # Not inside an exception handler: report without a traceback
        report(
            f"Could not find any '*transform*' key among the provided config"
            f" keys: {transform_options.keys()}. Are you sure you passed a "
            f"datamodule config as input ?")
    return transforms_dict
def explode_transform(transform_list):
    """Return a flat, deep-copied list of the given transforms.

    :param transform_list: None, a torch_geometric `Compose`, or a list
        of transforms
    :return: a new list (empty when `transform_list` is None) holding
        deep copies of the transforms, so mutating the result never
        affects the input
    :raises Exception: if `transform_list` is neither None, a `Compose`
        nor a list
    """
    if transform_list is None:
        return []

    if isinstance(transform_list, pygT.Compose):
        return copy.deepcopy(transform_list.transforms)

    if isinstance(transform_list, list):
        return copy.deepcopy(transform_list)

    raise Exception(
        "Transforms should be provided either within a list or "
        "a Compose")