File size: 2,345 Bytes
d38bce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from abc import ABCMeta
import torch

class BaseAttack(object):
    """
    Attack base class.

    Subclasses provide the attack algorithm in ``generate`` (and, optionally,
    parameter handling in ``parse_params``). ``check_type_device`` normalizes
    the image/label inputs and places them, together with the model, on the
    configured device.
    """

    # NOTE: in Python 3 this class attribute has no effect (a metaclass must be
    # declared as ``class X(metaclass=...)``); kept as-is for compatibility.
    __metaclass__ = ABCMeta

    def __init__(self, model, device = 'cuda'):
        """
        Parameters
        ----------
        model :
            model under attack; it is called on images in
            ``get_or_predict_lable`` and moved with ``.cuda()`` / ``.cpu()``
            in ``check_type_device``
        device : str
            'cuda' or 'cpu'; validated in ``check_type_device``
        """
        self.model = model
        self.device = device

    def generate(self, image, label, **kwargs):
        """
        Overide this function for the main body of attack algorithm.

        Parameters
        ----------
        image :
            original image
        label :
            original label
        kwargs :
            user defined parameters

        Raises
        ------
        NotImplementedError
            always; the base class has no attack algorithm of its own.
        """
        raise NotImplementedError(
            '{} must override generate()'.format(type(self).__name__))

    def parse_params(self, **kwargs):
        """
        Parse user defined parameters.

        The base implementation accepts anything and returns True.
        """
        return True

    def check_type_device(self, image, label):
        """
        Check device, match variable type to device type.

        Converts ``image`` to a float32 tensor and ``label`` to an int64
        tensor, moves both (and ``self.model``) to ``self.device``, and stores
        them as ``self.image`` / ``self.label``. The stored image is a fresh
        leaf tensor on the target device with ``requires_grad=True``, so the
        gradient of a loss with respect to it lands in ``self.image.grad``.

        Parameters
        ----------
        image : torch.Tensor or numpy.ndarray
            image
        label : torch.Tensor or numpy.ndarray
            label

        Returns
        -------
        bool
            True on success.

        Raises
        ------
        ValueError
            if ``self.device`` is neither 'cpu' nor 'cuda', or if ``image`` /
            ``label`` is neither a numpy array nor a torch tensor.
        """

        ################## devices
        # Validate up front so a bad device is reported before any work.
        if self.device not in ('cuda', 'cpu'):
            raise ValueError('Please input cpu or cuda')

        ################## data type
        # Convert to tensors *before* moving to the device: numpy arrays have
        # no .cuda()/.cpu(), so converting afterwards made the ndarray
        # branches unreachable.
        if isinstance(image, torch.Tensor):
            image = image.float()
        elif type(image).__name__ == 'ndarray':
            # float32 to match the Tensor branch (.float() is float32).
            image = torch.tensor(image.astype('float32'))
        else:
            raise ValueError('Input values only take numpy arrays or torch tensors')

        if isinstance(label, torch.Tensor):
            label = label.long()
        elif type(label).__name__ == 'ndarray':
            # 'int64' rather than 'long': numpy's C long is 32-bit on Windows.
            label = torch.tensor(label.astype('int64'))
        else:
            raise ValueError('Input labels only take numpy arrays or torch tensors')

        ################## move to device
        if self.device == 'cuda':
            image = image.cuda()
            label = label.cuda()
            self.model = self.model.cuda()
        else:
            image = image.cpu()
            label = label.cpu()
            self.model = self.model.cpu()

        # Detach from any caller graph and make a leaf that requires grad
        # *after* the device move. A tensor created with requires_grad=True
        # and then moved would be a non-leaf whose .grad is never populated.
        image = image.clone().detach().requires_grad_(True)

        #################### set init attributes
        self.image = image
        self.label = label

        return True

    def get_or_predict_lable(self, image):
        """
        Predict labels for ``image`` with ``self.model``.

        The method name keeps its historical spelling for backward
        compatibility.

        Parameters
        ----------
        image :
            input passed directly to ``self.model``

        Returns
        -------
        torch.Tensor
            ``argmax`` of the model output over dim 1 with ``keepdim=True``
            (e.g. shape (N, 1) for (N, C) logits).
        """
        output = self.model(image)
        pred = output.argmax(dim=1, keepdim=True)
        return pred