| """ToTensor transformation.""" |
|
|
| import numpy as np |
| import torch |
|
|
| from vis4d.data.const import CommonKeys as K |
| from vis4d.data.typing import DictData |
|
|
| from .base import Transform |
|
|
|
|
def _as_tensor(array: np.ndarray) -> torch.Tensor:
    """Wrap a numpy array as a tensor, copying only when torch requires it.

    ``torch.from_numpy`` shares memory with ``array`` but raises a
    ``ValueError`` for arrays with negative strides (e.g. views created by
    ``arr[::-1]`` or ``np.flip``). Such arrays are first copied into a fresh
    C-ordered buffer; every other array is converted without a copy.
    """
    if any(stride < 0 for stride in array.strides):
        array = array.copy()
    return torch.from_numpy(array)


def _images_to_tensor(images: np.ndarray) -> torch.Tensor:
    """Convert a rank-4 NHWC image array into a contiguous NCHW tensor.

    ``np.ascontiguousarray`` materializes the transposed view with positive
    strides, so this works for any input memory layout.
    """
    return torch.from_numpy(
        np.ascontiguousarray(images.transpose(0, 3, 1, 2))
    )


def _replace_arrays(data: DictData) -> None:
    """Replace numpy arrays in ``data`` with torch tensors, in place.

    Conversion rules, applied per key:

    - ``K.images`` / ``K.original_images``: permuted from NHWC to NCHW and
      made contiguous. Values that are already ``torch.Tensor`` are left
      untouched, so running the conversion twice is harmless.
    - any other ``np.ndarray``: converted with ``torch.from_numpy``. Memory
      is shared when possible; arrays with negative strides are copied.
    - nested ``dict``: processed recursively with the same rules.
    - ``list``: updated in place; only its direct ``np.ndarray`` entries
      are converted.

    All other values are left unchanged.

    Args:
        data: Dictionary to convert. It is modified in place.
    """
    for key, value in data.items():
        # Assigning to existing keys does not resize the dict, so mutating
        # values while iterating over items() is safe.
        if key in (K.images, K.original_images):
            if not isinstance(value, torch.Tensor):
                data[key] = _images_to_tensor(value)
        elif isinstance(value, np.ndarray):
            data[key] = _as_tensor(value)
        elif isinstance(value, dict):
            _replace_arrays(value)
        elif isinstance(value, list):
            # Mutate the list object itself so other references to it see
            # the converted entries, matching the original behaviour.
            for i, entry in enumerate(value):
                if isinstance(entry, np.ndarray):
                    value[i] = _as_tensor(entry)
|
|
|
|
@Transform("data", "data")
class ToTensor:
    """Convert numpy arrays in each DictData of a batch to torch tensors.

    Every sample dict is converted in place by ``_replace_arrays``. Entries
    stored under ``K.images`` and ``K.original_images`` are additionally
    permuted (not reshaped) from NHWC to NCHW layout.
    """

    def __call__(self, batch: list[DictData]) -> list[DictData]:
        """Convert each sample of ``batch`` and hand the same list back.

        Args:
            batch: Sample dicts to convert; each one is modified in place.

        Returns:
            The very same ``batch`` list object, now holding tensors.
        """
        for sample in batch:
            _replace_arrays(sample)
        return batch
|
|