File size: 501 Bytes
ab0f6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from collections.abc import Sequence, Mapping

def batch_apply(batch, fn):
  """Recursively apply ``fn`` to every tensor in a nested batch structure.

  Args:
    batch: A ``torch.Tensor``, or an arbitrarily nested ``Sequence`` /
      ``Mapping`` whose leaves are tensors.
    fn: Callable invoked on each tensor leaf; its return value replaces that
      leaf in the output.

  Returns:
    ``fn(batch)`` for a tensor; a ``list`` for any non-text ``Sequence``
    (tuples and other sequence types are converted to lists); a ``dict``
    with the same keys for any ``Mapping``.

  Raises:
    NotImplementedError: If any element is not a tensor, sequence or
      mapping. ``str``/``bytes``/``bytearray`` are also rejected, even though
      they are ``Sequence`` instances (see the comment below).
  """
  if isinstance(batch, torch.Tensor):
    return fn(batch)
  # Text types register as Sequence, but a one-character str iterates to
  # itself, so recursing into it would never terminate (RecursionError).
  # Route them to the unsupported-type error instead.
  if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes, bytearray)):
    return [batch_apply(item, fn) for item in batch]
  if isinstance(batch, Mapping):
    return {key: batch_apply(value, fn) for key, value in batch.items()}
  raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')

def batch_to(batch, device, *, non_blocking=False):
  """Move every tensor in a nested batch structure to ``device``.

  Args:
    batch: A tensor or a nested ``Sequence``/``Mapping`` of tensors, in any
      form accepted by ``batch_apply``.
    device: Target forwarded to ``torch.Tensor.to`` (e.g. ``'cuda'`` or a
      ``torch.device``).
    non_blocking: Forwarded to ``torch.Tensor.to``. The default ``False`` is
      ``Tensor.to``'s own default, so omitting it keeps the original
      behaviour. Setting it to ``True`` allows asynchronous copies where
      PyTorch supports them (e.g. from pinned host memory).

  Returns:
    The structure produced by ``batch_apply``, with each tensor leaf replaced
    by ``leaf.to(device, non_blocking=non_blocking)``.
  """
  return batch_apply(batch, lambda x: x.to(device, non_blocking=non_blocking))