Spaces:
Runtime error
Runtime error
| import copy | |
| import torch | |
| from torchvision.transforms import Compose | |
# Registry mapping a transform name to the class registered under that name.
_TRANSFORM_DICT = dict()


def register_transform(name):
    """Return a class decorator that records the class under ``name``.

    The decorated class is stored in ``_TRANSFORM_DICT[name]`` (replacing any
    earlier entry with the same name) and returned unchanged, so that
    ``get_transform`` can later look it up by name.
    """
    def _register(klass):
        _TRANSFORM_DICT[name] = klass
        return klass

    return _register
def get_transform(cfg):
    """Build a composed transform pipeline from a list of config mappings.

    Each entry of ``cfg`` is a mapping with a ``'type'`` key naming a class
    registered via ``register_transform``; the remaining keys are passed to
    that class's constructor as keyword arguments.

    Args:
        cfg: Sequence of transform config mappings, or ``None``.

    Returns:
        A ``Compose`` of the instantiated transforms in ``cfg`` order, or
        ``None`` when ``cfg`` is ``None`` or empty.

    Raises:
        KeyError: if an entry has no ``'type'`` key, or names a transform
            that has not been registered.
    """
    if cfg is None or len(cfg) == 0:
        return None
    tfms = []
    for t_dict in cfg:
        # Work on a copy so that popping 'type' never mutates the caller's config.
        t_dict = copy.deepcopy(t_dict)
        if 'type' not in t_dict:
            raise KeyError(f"Transform config {t_dict!r} is missing the 'type' key")
        name = t_dict.pop('type')
        try:
            cls = _TRANSFORM_DICT[name]
        except KeyError:
            raise KeyError(
                f"Unknown transform type {name!r}; "
                f"registered types: {list(_TRANSFORM_DICT)}"
            ) from None
        tfms.append(cls(**t_dict))
    return Compose(tfms)
def _index_select(v, index, n):
    """Gather ``v`` along its first axis with ``index`` if it has length ``n``.

    Tensors whose first dimension equals ``n`` are indexed as ``v[index]``.
    Lists of length ``n`` are re-gathered element by element. Anything else,
    including 0-dim (scalar) tensors, which have no first dimension, is
    returned unchanged.
    """
    if isinstance(v, torch.Tensor):
        # v.size(0) raises IndexError on a 0-dim tensor, so check dim() first.
        if v.dim() > 0 and v.size(0) == n:
            return v[index]
        return v
    if isinstance(v, list) and len(v) == n:
        return [v[i] for i in index]
    return v
def _index_select_data(data, index):
    """Apply ``_index_select`` with ``index`` to every value of ``data``.

    The item count ``n`` is taken from the first dimension of ``data['aa']``.
    Only values whose length matches that count are re-indexed; other values
    are carried over as-is.

    Returns:
        A new dict with the same keys as ``data``.

    Raises:
        KeyError: if ``data`` is non-empty and has no ``'aa'`` entry.
    """
    if not data:
        # The original never evaluated data['aa'] for an empty dict; keep that.
        return {}
    # Loop-invariant: compute the reference length once, not once per key.
    n = data['aa'].size(0)
    return {k: _index_select(v, index, n) for k, v in data.items()}
def _mask_select(v, mask):
    """Keep the entries of ``v`` where ``mask`` is true, if ``v`` matches its length.

    ``mask`` is expected to be a 1-D boolean tensor. Tensors whose first
    dimension equals ``mask.size(0)`` are masked as ``v[mask]``. Lists of the
    same length keep only the elements at true positions. Anything else,
    including 0-dim (scalar) tensors, is returned unchanged.
    """
    if isinstance(v, torch.Tensor):
        # v.size(0) raises IndexError on a 0-dim tensor, so check dim() first.
        if v.dim() > 0 and v.size(0) == mask.size(0):
            return v[mask]
        return v
    if isinstance(v, list) and len(v) == mask.size(0):
        # A single tolist() instead of truth-testing one 0-d tensor per element
        # (for a CUDA mask, each of those would be a separate device-to-host copy).
        return [item for item, keep in zip(v, mask.tolist()) if keep]
    return v
def _mask_select_data(data, mask):
    """Apply ``_mask_select`` with ``mask`` to every value of ``data``.

    Returns a new dict with the same keys, in the same order, holding the
    per-value results of ``_mask_select``.
    """
    selected = {}
    for key, value in data.items():
        selected[key] = _mask_select(value, mask)
    return selected