import torch
from collections.abc import Sequence, Mapping
def batch_apply(batch, fn):
    """Recursively apply *fn* to every tensor leaf of a nested batch.

    Args:
        batch: A ``torch.Tensor``, or an arbitrarily nested ``Sequence`` /
            ``Mapping`` whose leaves are tensors.
        fn: Callable applied to each tensor leaf.

    Returns:
        The same nesting structure with every tensor replaced by
        ``fn(tensor)``; sequences come back as lists, mappings as dicts.

    Raises:
        NotImplementedError: If a node is of an unsupported type
            (including ``str``/``bytes``).
    """
    if isinstance(batch, torch.Tensor):
        return fn(batch)
    # str/bytes are Sequences: without this guard, recursing into a string
    # loops forever (each element of a 1-char str is itself a str) and dies
    # with RecursionError instead of the intended NotImplementedError.
    if isinstance(batch, (str, bytes)):
        raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')
    if isinstance(batch, Sequence):
        return [batch_apply(item, fn) for item in batch]
    if isinstance(batch, Mapping):
        # .items() avoids a second lookup per key.
        return {key: batch_apply(value, fn) for key, value in batch.items()}
    raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')
def batch_to(batch, device):
    """Move every tensor in a (possibly nested) batch onto *device*.

    Thin convenience wrapper around :func:`batch_apply` that calls
    ``tensor.to(device)`` on each tensor leaf and returns the resulting
    structure (lists for sequences, dicts for mappings).
    """
    def _move(tensor):
        return tensor.to(device)

    return batch_apply(batch, _move)