import torch
from src.transforms import Transform
from src.data import NAG


__all__ = ['DataTo', 'NAGTo']


class DataTo(Transform):
    """Move Data object to specified device."""

    def __init__(self, device):
        # Normalize strings (e.g. 'cpu', 'cuda:0') to a torch.device
        if not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device

    def _process(self, data):
        # Skip the transfer when the data is already on the target device
        if data.device == self.device:
            return data
        return data.to(self.device)


class NAGTo(Transform):
    """Move NAG object to specified device."""

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, device):
        # Normalize strings (e.g. 'cpu', 'cuda:0') to a torch.device
        if not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device

    def _process(self, nag):
        # Skip the transfer when the NAG is already on the target device
        if nag.device == self.device:
            return nag
        return nag.to(self.device)