| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def pad_x_to_y(x, y, axis: int = -1):
    """Right-pad or right-trim ``x`` along ``axis`` to match ``y``'s length.

    Args:
        x (:class:`torch.Tensor`): Tensor to pad or trim.
        y (:class:`torch.Tensor`): Tensor whose size along ``axis`` is the
            target length.
        axis (int): Dimension along which the lengths are matched.
            Defaults to the last dimension.

    Returns:
        :class:`torch.Tensor`: ``x`` zero-padded at the end of ``axis`` when it
        is shorter than ``y``, or cut at the end when it is longer.

    Raises:
        IndexError: If ``axis`` is out of range for ``x`` or ``y``.
    """
    target_len = y.shape[axis]
    current_len = x.shape[axis]
    # ``F.pad`` takes (left, right) pairs starting from the LAST dimension,
    # so every dimension located after ``axis`` gets an untouched (0, 0) pair.
    n_trailing_dims = x.dim() - 1 - (axis % x.dim())
    # A negative right-pad amount trims, which covers the case len(x) > len(y).
    pad = [0, 0] * n_trailing_dims + [0, target_len - current_len]
    return nn.functional.pad(x, pad)
| |
|
| |
|
def shape_reconstructed(reconstructed, size):
    """Drop the leading dimension of ``reconstructed`` for 1-D original sizes.

    Args:
        reconstructed: Object exposing ``squeeze`` (e.g. a
            :class:`torch.Tensor`).
        size: Sized object describing the original shape; only its length
            is inspected.

    Returns:
        ``reconstructed.squeeze(0)`` when ``size`` has exactly one entry,
        otherwise ``reconstructed`` unchanged.
    """
    was_one_dimensional = len(size) == 1
    return reconstructed.squeeze(0) if was_one_dimensional else reconstructed
| |
|
| |
|
def tensors_to_device(tensors, device):
    """Transfer tensor, dict, list or tuple of tensors to device.

    Args:
        tensors (:class:`torch.Tensor`): May be a single tensor, or a list,
            tuple or dictionary of tensors (possibly nested).
        device (:class:`torch.device`): the device where to place the tensors.

    Returns:
        Union [:class:`torch.Tensor`, list, tuple, dict]:
            Same as input but transferred to device.
            Goes through lists, tuples and dicts and transfers the
            torch.Tensor to device. Leaves the rest untouched.
            A list input yields a new list and a tuple input yields a new
            tuple; a dict input is updated in place and that same dict is
            returned.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, (list, tuple)):
        moved = [tensors_to_device(item, device) for item in tensors]
        # Hand back a tuple for tuple input so the result matches the input
        # container kind, as the documented contract states.
        return tuple(moved) if isinstance(tensors, tuple) else moved
    if isinstance(tensors, dict):
        # Only values of existing keys are reassigned, so the dict does not
        # change size while it is being iterated.
        for key in tensors:
            tensors[key] = tensors_to_device(tensors[key], device)
        return tensors
    # Anything that is not a tensor or a supported container passes through.
    return tensors
| |
|