English
File size: 858 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from src.transforms import Transform
from src.data import NAG


__all__ = ['DataTo', 'NAGTo']


class DataTo(Transform):
    """Transform placing a Data object on a given torch device.

    :param device: str | torch.device
        Target device. Any value that is not already a `torch.device`
        is converted with `torch.device(...)`.
    """

    def __init__(self, device):
        self.device = (
            device if isinstance(device, torch.device)
            else torch.device(device))

    def _process(self, data):
        # Skip the transfer when the object already reports the target
        # device; otherwise delegate to the object's own `.to()`.
        already_there = data.device == self.device
        return data if already_there else data.to(self.device)


class NAGTo(Transform):
    """Move NAG object to specified device.

    :param device: str | torch.device
        Target device. Any value that is not already a `torch.device`
        is converted with `torch.device(...)`.
    """

    # This transform consumes and produces NAG objects rather than Data.
    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, device):
        if not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device

    def _process(self, nag):
        # Return the input unchanged when it already reports the target
        # device. Otherwise rely on `NAG.to` to perform the transfer.
        # NOTE: the comparison is exact, so e.g. torch.device('cuda') and
        # torch.device('cuda:0') are not equal and fall through to `.to`.
        if nag.device == self.device:
            return nag
        return nag.to(self.device)