File size: 858 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from src.transforms import Transform
from src.data import NAG
__all__ = ['DataTo', 'NAGTo']
class DataTo(Transform):
    """Transfer a Data object onto a target device.

    Parameters
    ----------
    device : str | torch.device
        Destination device. A value that is already a ``torch.device`` is
        stored as-is; anything else is converted with ``torch.device(...)``.
    """

    def __init__(self, device):
        self.device = (
            device if isinstance(device, torch.device)
            else torch.device(device))

    def _process(self, data):
        # When the object already lives on the destination device, hand it
        # back untouched instead of issuing a redundant `.to()` call
        already_there = data.device == self.device
        return data if already_there else data.to(self.device)
class NAGTo(Transform):
    """Move a NAG object to the specified device.

    Parameters
    ----------
    device : str | torch.device
        Target device. Values that are not already a ``torch.device`` are
        converted with ``torch.device(device)``.
    """

    # Input/output types of this transform; presumably consumed by the
    # `Transform` base class for type checking/dispatch — confirm there
    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, device):
        if not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device

    def _process(self, nag):
        # Skip the transfer when the NAG is already on the target device;
        # in that case the very same input object is returned
        if nag.device == self.device:
            return nag
        return nag.to(self.device)
|