| import torch | |
| from collections.abc import Sequence, Mapping | |
def batch_apply(batch, fn):
    """Recursively apply ``fn`` to every tensor in a nested batch structure.

    Args:
        batch: A ``torch.Tensor``, or an arbitrarily nested combination of
            sequences (lists, tuples, ...) and mappings whose leaves are
            tensors.
        fn: Callable invoked on each tensor leaf; its return value replaces
            that leaf in the output.

    Returns:
        A structure mirroring ``batch`` in which each tensor is replaced by
        ``fn(tensor)``. Every sequence is rebuilt as a ``list`` and every
        mapping as a ``dict`` (the original container types are not kept).

    Raises:
        NotImplementedError: If any leaf is not a tensor, sequence or mapping.
            ``str`` and ``bytes`` are treated as unsupported leaves.
    """
    if isinstance(batch, torch.Tensor):
        return fn(batch)
    # str/bytes are registered as Sequences, but iterating a str yields more
    # strs, so recursing into one never terminates (RecursionError). Reject
    # them as unsupported leaves instead of descending into them.
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        return [batch_apply(item, fn) for item in batch]
    if isinstance(batch, Mapping):
        return {key: batch_apply(value, fn) for key, value in batch.items()}
    raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')
def batch_to(batch, device, *, non_blocking=False):
    """Move every tensor in a nested batch structure to ``device``.

    Args:
        batch: A tensor, or a nested combination of sequences and mappings
            with tensor leaves, as accepted by ``batch_apply``.
        device: Target passed straight through to ``Tensor.to``.
        non_blocking: Forwarded to ``Tensor.to``. The default ``False``
            reproduces a plain ``tensor.to(device)`` call.

    Returns:
        The same structure as produced by ``batch_apply``, with each tensor
        replaced by ``tensor.to(device, non_blocking=non_blocking)``.

    Raises:
        NotImplementedError: Propagated from ``batch_apply`` for leaves that
            are not tensors, sequences or mappings.
    """
    return batch_apply(
        batch, lambda tensor: tensor.to(device, non_blocking=non_blocking)
    )