|
|
|
|
|
|
|
|
from collections import defaultdict |
|
|
from dataclasses import fields, is_dataclass |
|
|
from typing import Any, Mapping, Protocol, runtime_checkable |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def _is_named_tuple(x) -> bool:
    """Return True when ``x`` looks like a named-tuple instance.

    The check is duck-typed: ``x`` must be a ``tuple`` that also exposes both
    ``_asdict`` and ``_fields`` attributes.
    """
    if not isinstance(x, tuple):
        return False
    # Both markers must be present; checked in the same order as before.
    return all(hasattr(x, marker) for marker in ("_asdict", "_fields"))
|
|
|
|
|
|
|
|
@runtime_checkable
class _CopyableData(Protocol):
    """Structural type for objects that can be moved via a ``to`` method.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, _CopyableData)``
    only checks that ``obj`` provides a ``to`` member; the method's signature
    is not validated at runtime.
    """

    def to(self, device: torch.device, *args: Any, **kwargs: Any):
        """Return a copy of this object placed on ``device``.

        Extra positional and keyword arguments are forwarded unchanged.
        """
        ...
|
|
|
|
|
|
|
|
def _copy_dataclass_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Rebuild dataclass instance ``data`` with every field copied to ``device``.

    Fields declared with ``init=True`` are passed to the constructor. Fields
    declared with ``init=False`` are assigned afterwards with
    ``object.__setattr__``, so frozen dataclasses (whose ``__setattr__``
    raises) are supported. An ``init=False`` field that was never set on
    ``data`` is skipped rather than raising ``AttributeError``.
    """
    all_fields = fields(data)
    new_data = type(data)(
        **{
            f.name: copy_data_to_device(getattr(data, f.name), device, *args, **kwargs)
            for f in all_fields
            if f.init
        }
    )
    missing = object()
    for f in all_fields:
        if f.init:
            continue
        value = getattr(data, f.name, missing)
        if value is missing:
            # No default and never assigned on the source: nothing to copy,
            # so leave the new instance as its constructor produced it.
            continue
        # Plain setattr raises FrozenInstanceError on frozen dataclasses;
        # bypass it the same way the dataclass-generated __init__ does.
        object.__setattr__(
            new_data, f.name, copy_data_to_device(value, device, *args, **kwargs)
        )
    return new_data


def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Function that recursively copies data to a torch.device.

    Containers are rebuilt with the same type and their contents copied
    recursively:

    * named tuples (rebuilt from their fields by keyword),
    * lists and tuples,
    * ``defaultdict`` (keeping its ``default_factory``),
    * any other ``Mapping`` (rebuilt as ``type(data)(dict_of_copies)``),
    * dataclass instances, including frozen ones and ``init=False`` fields.

    Any object matching ``_CopyableData`` (i.e. it has a ``to`` attribute) is
    copied with ``data.to(device, *args, **kwargs)``. Everything else is
    returned unchanged.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    # Named tuples must be checked before plain tuples: their constructors
    # take one argument per field rather than a single iterable.
    if _is_named_tuple(data):
        return type(data)(
            **copy_data_to_device(data._asdict(), device, *args, **kwargs)
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
    # defaultdict must precede the generic Mapping branch so that its
    # default_factory is passed through to the new instance.
    elif isinstance(data, defaultdict):
        return type(data)(
            data.default_factory,
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            },
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            }
        )
    # is_dataclass is also true for dataclass *classes*; only copy instances.
    elif is_dataclass(data) and not isinstance(data, type):
        return _copy_dataclass_to_device(data, device, *args, **kwargs)
    elif isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)
    return data
|
|
|