English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
import torch
from src.transforms import Transform
from src.data import NAG
__all__ = ['DataTo', 'NAGTo']
class DataTo(Transform):
    """Transform that moves a Data object onto a given device.

    :param device: torch.device or any value accepted by the
        `torch.device` constructor (e.g. a device string). It is
        converted to a `torch.device` once, at construction time.
    """

    def __init__(self, device):
        # Normalize the argument so that `self.device` is always a
        # `torch.device`, which `_process` compares against
        self.device = (
            device if isinstance(device, torch.device)
            else torch.device(device))

    def _process(self, data):
        """Return `data` on `self.device`.

        The input object itself is returned when its `device` already
        equals the target; otherwise the result of `data.to(...)` is
        returned.
        """
        already_on_target = data.device == self.device
        return data if already_on_target else data.to(self.device)
class NAGTo(Transform):
    """Transform that moves a NAG object onto a given device.

    :param device: torch.device or any value accepted by the
        `torch.device` constructor (e.g. a device string). It is
        converted to a `torch.device` once, at construction time.
    """

    # This transform consumes and produces NAG objects
    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, device):
        # Normalize the argument so that `self.device` is always a
        # `torch.device`, which `_process` compares against
        target = device
        if not isinstance(target, torch.device):
            target = torch.device(target)
        self.device = target

    def _process(self, nag):
        """Return `nag` on `self.device`.

        The input object itself is returned when its `device` already
        equals the target; otherwise the result of `nag.to(...)` is
        returned.
        """
        on_target = nag.device == self.device
        if not on_target:
            nag = nag.to(self.device)
        return nag