| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass, fields |
|
|
|
|
class Transform:
    """Base class for transforms that know how to batch their datastructs.

    Subclasses are expected to expose a ``Datastruct`` attribute: the
    datastruct class that ``collate`` instantiates via ``self.Datastruct``.
    """

    def collate(self, lst_datastruct):
        """Collate a list of datastructs into one batched datastruct.

        For every key in ``lst_datastruct[0].datakeys`` the per-item values
        are gathered in order and passed to ``collate_tensor_with_padding``.
        If the *first* datastruct holds ``None`` for a key, that key is
        ``None`` in the result and the other items are not inspected.

        Args:
            lst_datastruct: non-empty sequence of datastructs supporting
                ``item[key]`` access and exposing ``datakeys``.

        Returns:
            ``self.Datastruct(**kwargs)`` built from the collated values.

        Raises:
            IndexError: if ``lst_datastruct`` is empty.
        """
        if not lst_datastruct:
            # Same exception type as the implicit ``lst_datastruct[0]``
            # failure, but with a message that says what went wrong.
            raise IndexError("cannot collate an empty list of datastructs")

        # Lazy import kept from the original; presumably avoids an import
        # cycle with the ``tools`` package -- TODO confirm.
        from ..tools import collate_tensor_with_padding

        example = lst_datastruct[0]

        def collate_or_none(key):
            # The first item decides whether this key is present at all.
            if example[key] is None:
                return None
            return collate_tensor_with_padding([x[key] for x in lst_datastruct])

        kwargs = {key: collate_or_none(key) for key in example.datakeys}
        return self.Datastruct(**kwargs)
|
|
|
|
| |
| |
@dataclass
class Datastruct:
    """Dataclass base that adds dict-style access to its fields.

    Subclasses must provide ``datakeys`` (the names of the entries handled
    by ``to``, ``device`` and ``detach``; entries may be ``None``) and, for
    ``detach``, a ``transforms`` attribute exposing a ``Datastruct`` class.
    Non-None datakey entries must support ``.to``, ``.device`` and
    ``.detach`` (presumably torch tensors -- TODO confirm).
    """

    def __getitem__(self, key):
        """Return attribute ``key``; raises ``AttributeError`` if missing."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` in the instance ``__dict__``."""
        # Written straight into __dict__, bypassing any __setattr__ override.
        self.__dict__[key] = value

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` when it does not exist."""
        return getattr(self, key, default)

    def __iter__(self):
        """Iterate over field names, mirroring iteration over a dict."""
        return self.keys()

    def keys(self):
        """Return an iterator over the dataclass field names."""
        # Build the list eagerly so the iterator is a snapshot.
        keys = [t.name for t in fields(self)]
        return iter(keys)

    def values(self):
        """Return an iterator over the field values (snapshot at call time)."""
        values = [getattr(self, t.name) for t in fields(self)]
        return iter(values)

    def items(self):
        """Return an iterator of ``(name, value)`` pairs for every field."""
        data = [(t.name, getattr(self, t.name)) for t in fields(self)]
        return iter(data)

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on each non-None datakey entry.

        The entries are replaced in place and ``self`` is returned, which
        allows chaining.
        """
        for key in self.datakeys:
            if self[key] is not None:
                self[key] = self[key].to(*args, **kwargs)
        return self

    @property
    def device(self):
        """Device of the first non-None datakey entry.

        ``to`` and ``detach`` accept ``None`` entries, so the first datakey
        may legitimately be unset; skip such entries instead of crashing.

        Raises:
            AttributeError: if every datakey entry is ``None`` (or there
                are no datakeys).
        """
        for key in self.datakeys:
            value = self[key]
            if value is not None:
                return value.device
        raise AttributeError(
            f"{type(self).__name__} has no non-None datakey entry "
            "to take a device from"
        )

    def detach(self):
        """Return a new datastruct holding detached copies of the datakeys.

        ``None`` entries stay ``None``. The result is built with
        ``self.transforms.Datastruct`` and receives only the datakeys.
        """

        def detach_or_none(tensor):
            if tensor is not None:
                return tensor.detach()
            return None

        kwargs = {key: detach_or_none(self[key]) for key in self.datakeys}
        return self.transforms.Datastruct(**kwargs)
|
|