| | |
| | |
| | |
| | |
| | |
| | |
import itertools

import numpy as np
import torch
| |
|
| |
|
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer (possibly nested) data to another device.

    Recursively walks dicts, lists and tuples (including namedtuples) and
    converts every leaf:

    * ``device == 'numpy'``: torch tensors are detached, moved to CPU and
      converted to numpy arrays; every other leaf is returned unchanged.
    * otherwise (a torch device or device string such as ``'cpu'``/``'cuda'``):
      numpy arrays are first wrapped with ``torch.from_numpy``, then torch
      tensors are moved with ``Tensor.to``; ``None`` and other leaves are
      returned unchanged.

    Args:
        batch: a tensor, array, or arbitrarily nested dict/list/tuple of them.
        device: a torch device (or device string) or the string ``'numpy'``.
        callback: optional function applied to ``batch`` and, recursively, to
            every nested sub-element before it is converted; its return value
            replaces that element.
        non_blocking: forwarded to ``Tensor.to`` for every tensor, at any depth.

    Returns:
        A structure mirroring ``batch`` with converted leaves. Dicts are
        rebuilt as plain ``dict``; lists/tuples keep their type.
    """
    if callback is not None:
        batch = callback(batch)

    if isinstance(batch, dict):
        # Forward callback/non_blocking so nested elements are treated the
        # same way as top-level ones.
        return {k: todevice(v, device, callback, non_blocking) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        items = [todevice(x, device, callback, non_blocking) for x in batch]
        if isinstance(batch, tuple) and hasattr(batch, '_fields'):
            # namedtuple constructors take positional fields, not one iterable.
            return type(batch)(*items)
        return type(batch)(items)

    x = batch
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
| |
|
| |
|
# Alias: ``to_device`` and ``todevice`` refer to the same function object.
to_device = todevice
| |
|
| |
|
def to_numpy(x):
    """Return ``x`` with every torch tensor (at any nesting depth) converted to a numpy array."""
    return todevice(x, 'numpy')


def to_cpu(x):
    """Return ``x`` with every tensor/numpy array (at any nesting depth) as a torch tensor on ``'cpu'``."""
    return todevice(x, 'cpu')


def to_cuda(x):
    """Return ``x`` with every tensor/numpy array (at any nesting depth) as a torch tensor on ``'cuda'``."""
    return todevice(x, 'cuda')
| |
|
| |
|
def collate_with_cat(whatever, lists=False):
    """Collate a batch of (nested) results by concatenation.

    Recursively merges ``whatever``:

    * dict: each value is collated independently.
    * list/tuple, dispatched on the type of its first element:
        - empty container: returned as-is;
        - ``None``: returns ``None``;
        - bool/float/int/str: the container is returned as-is;
        - tuple: elements are transposed with ``zip`` and each position is
          collated, producing a container of the outer type;
        - dict: each key of the first element is collated across elements;
        - torch tensor / numpy array: concatenated with ``torch.cat`` (numpy
          arrays go through ``torch.from_numpy`` first), or, when ``lists``
          is true, flattened into a plain Python list via ``listify``;
        - anything else: the elements are assumed iterable and are chained
          into a single container of the outer type.
    * any other object: returned unchanged.

    Args:
        whatever: the object to collate, typically a list of per-batch outputs.
        lists: if true, keep tensors/arrays as a flat Python list instead of
            concatenating them with ``torch.cat``.

    Returns:
        The collated structure.
    """
    if isinstance(whatever, dict):
        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}

    if not isinstance(whatever, (tuple, list)):
        # Leaves have nothing to merge: return them instead of implicitly
        # returning None.
        return whatever

    if not whatever:
        return whatever
    elem = whatever[0]
    T = type(whatever)

    if elem is None:
        return None
    if isinstance(elem, (bool, float, int, str)):
        return whatever
    if isinstance(elem, tuple):
        return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
    if isinstance(elem, dict):
        return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}

    if isinstance(elem, torch.Tensor):
        return listify(whatever) if lists else torch.cat(whatever)
    if isinstance(elem, np.ndarray):
        return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])

    # Otherwise chain the (iterable) elements. chain is linear in the total
    # size, unlike sum(whatever, T()), which re-copies the accumulator on every
    # addition.
    return T(itertools.chain.from_iterable(whatever))
| |
|
| |
|
def listify(elems):
    """Flatten one level of nesting.

    Iterates each element of ``elems`` and appends its items, in order, to a
    single new list.
    """
    flat = []
    for group in elems:
        flat.extend(group)
    return flat
| |
|