|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import fields, is_dataclass |
|
|
from typing import Any, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def move_data_to_device(inputs: Any, device: Union[str, torch.device], non_blocking: bool = True) -> Any:
    """Recursively move every tensor found in ``inputs`` to ``device``.

    Handled structures:
      * ``torch.Tensor`` -- moved with ``Tensor.to(device, non_blocking=...)``.
      * namedtuples -- rebuilt as the same namedtuple type.
      * ``list`` / ``tuple`` / ``set`` (and subclasses whose constructor
        accepts a single iterable) -- rebuilt as the same type.
      * ``dict`` -- rebuilt as a plain ``dict`` (subclass type is not kept).
      * dataclass *instances* -- rebuilt as the same type. ``init=False``
        fields are copied onto the new instance after construction.
      * anything else (including ``None`` and dataclass classes) -- returned
        unchanged.

    Args:
        inputs: Arbitrarily nested data possibly containing tensors.
        device: Target device, e.g. ``"cpu"`` or ``torch.device("cuda:0")``.
        non_blocking: Forwarded to ``Tensor.to``.

    Returns:
        A structure mirroring ``inputs`` with tensors on ``device``.
    """
    if inputs is None:
        return None

    if isinstance(inputs, torch.Tensor):
        return inputs.to(device, non_blocking=non_blocking)

    if isinstance(inputs, tuple) and hasattr(inputs, "_fields"):
        # namedtuple constructors take the fields positionally, not as one iterable.
        return type(inputs)(*(move_data_to_device(i, device, non_blocking) for i in inputs))

    if isinstance(inputs, (list, tuple, set)):
        return inputs.__class__([move_data_to_device(i, device, non_blocking) for i in inputs])

    if isinstance(inputs, dict):
        return {k: move_data_to_device(v, device, non_blocking) for k, v in inputs.items()}

    # is_dataclass() is also True for the dataclass *class*; only rebuild instances.
    if is_dataclass(inputs) and not isinstance(inputs, type):
        init_values = {}
        non_init_values = {}
        for f in fields(inputs):
            if f.init:
                init_values[f.name] = move_data_to_device(getattr(inputs, f.name), device, non_blocking)
            elif hasattr(inputs, f.name):
                # init=False fields cannot be passed to __init__; restore them afterwards.
                non_init_values[f.name] = move_data_to_device(getattr(inputs, f.name), device, non_blocking)
        # NOTE: InitVar pseudo-fields are not reported by fields(), so dataclasses
        # that require InitVar arguments still cannot be rebuilt here.
        result = type(inputs)(**init_values)
        for name, value in non_init_values.items():
            # object.__setattr__ bypasses frozen=True, mirroring dataclasses' own __init__.
            object.__setattr__(result, name, value)
        return result

    return inputs
|
|
|