| |
|
|
| from collections import defaultdict |
| from dataclasses import fields, is_dataclass |
| from typing import Any, Mapping, Protocol, runtime_checkable |
|
|
| import torch |
|
|
|
|
def _is_named_tuple(x) -> bool:
    """Return True when ``x`` is a tuple that also exposes ``_asdict`` and ``_fields``.

    Those two attributes are what distinguish a ``namedtuple``-style object
    from a plain tuple.
    """
    if not isinstance(x, tuple):
        return False
    # Same attribute order as a chained ``and`` check; ``all`` short-circuits.
    return all(hasattr(x, attr) for attr in ("_asdict", "_fields"))
|
|
|
|
@runtime_checkable
class _CopyableData(Protocol):
    """Structural type for objects that expose a ``to`` method.

    Decorated with ``runtime_checkable`` so it can be used in ``isinstance``
    checks.
    """

    def to(self, device: torch.device, *args: Any, **kwargs: Any):
        """Copy data to the specified device"""
|
|
|
|
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Recursively copy ``data`` to a ``torch.device``.

    Containers are rebuilt with their own type and each element is processed
    recursively. Branches are tried in this order:

    * named tuples: rebuilt from the converted ``_asdict()`` mapping;
    * lists and tuples: rebuilt from the converted elements;
    * ``defaultdict``: rebuilt with the same ``default_factory``;
    * other ``Mapping`` instances: rebuilt from a dict of converted values;
    * dataclass instances: constructed from the converted ``init`` fields,
      then the remaining (``init=False``) fields are copied over;
    * objects exposing a ``to`` method: ``data.to(device, *args, **kwargs)``.

    Anything else is returned unchanged.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    if _is_named_tuple(data):
        return type(data)(
            **copy_data_to_device(data._asdict(), device, *args, **kwargs)
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
    elif isinstance(data, defaultdict):
        return type(data)(
            data.default_factory,
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            },
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            }
        )
    elif is_dataclass(data) and not isinstance(data, type):
        # Imported locally: only needed for the frozen-dataclass fallback.
        from dataclasses import FrozenInstanceError

        data_fields = fields(data)
        # Only ``init`` fields can be handed to the constructor.
        new_data_class = type(data)(
            **{
                field.name: copy_data_to_device(
                    getattr(data, field.name), device, *args, **kwargs
                )
                for field in data_fields
                if field.init
            }
        )
        # Non-init fields must be assigned after construction.
        for field in data_fields:
            if field.init:
                continue
            moved = copy_data_to_device(
                getattr(data, field.name), device, *args, **kwargs
            )
            try:
                setattr(new_data_class, field.name, moved)
            except FrozenInstanceError:
                # Frozen dataclasses reject regular attribute assignment;
                # bypass their __setattr__ so non-init fields are still copied
                # instead of aborting the whole transfer.
                object.__setattr__(new_data_class, field.name, moved)
        return new_data_class
    elif isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)
    return data
|
|