import torch
import torch.nn as nn

from typing import List, OrderedDict, Union

################################################################################
# Utilities for single/multi-GPU training
################################################################################


class DataParallelWrapper(nn.DataParallel):
    """Extend DataParallel class to allow full method/attribute access"""

    def __getattr__(self, name):
        # Resolve through DataParallel first, then fall back to the wrapped
        # module so its custom attributes and methods remain reachable.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

    def state_dict(self,
                   *args,
                   destination=None,
                   prefix='',
                   keep_vars=False):
        """Avoid `module` prefix in saved weights"""
        return self.module.state_dict(
            *args,
            destination=destination,
            prefix=prefix,
            keep_vars=keep_vars
        )

    def load_state_dict(self,
                        state_dict: OrderedDict[str, torch.Tensor],
                        strict: bool = True):
        """Avoid `module` prefix when loading saved weights"""
        return self.module.load_state_dict(state_dict, strict=strict)
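
    # Note (illustrative): a plain nn.DataParallel checkpoints parameters under
    # keys like "module.weight"; delegating both methods to the inner module
    # keeps keys as "weight", so checkpoints stay interchangeable between
    # single- and multi-GPU runs.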


def get_cuda_device_ids():
    """Fetch all available CUDA devices"""
    return list(range(torch.cuda.device_count()))


def wrap_module_multi_gpu(m: nn.Module, device_ids: List[int]):
    """Implement data parallelism for arbitrary Module objects."""
    if len(device_ids) < 1:
        return m  # no CUDA devices available: leave the module untouched

    if isinstance(m, DataParallelWrapper):
        return m  # already wrapped: wrapping is idempotent

    return DataParallelWrapper(
        module=m,
        device_ids=device_ids
    )
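
# Usage sketch (assumes at least one visible CUDA device; `nn.Linear` stands in
# for any Module):
#
#   model = wrap_module_multi_gpu(nn.Linear(4, 2), get_cuda_device_ids())
#   model = wrap_module_multi_gpu(model, get_cuda_device_ids())  # no-op
#   print(list(model.state_dict().keys()))  # ['weight', 'bias'], no prefix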


def unwrap_module_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """Undo `wrap_module_multi_gpu`, moving the bare module to `device`."""
    if isinstance(m, DataParallelWrapper):
        return m.module.to(device)
    return m.to(device)


def wrap_attack_multi_gpu(m: nn.Module, device_ids: List[int]):
    """
    Implement data parallelism for attack objects, including stored Pipeline
    and Perturbation instances that may be accessed outside of `forward()`.
    """
    if len(device_ids) < 1:
        return m

    if hasattr(m, 'pipeline') and isinstance(m.pipeline, nn.Module):
        m.pipeline = wrap_pipeline_multi_gpu(m.pipeline, device_ids)

    if hasattr(m, 'perturbation') and isinstance(m.perturbation, nn.Module):
        m.perturbation = wrap_module_multi_gpu(m.perturbation, device_ids)

    # scale batch size to the number of devices
    if hasattr(m, 'batch_size'):
        m.batch_size *= len(device_ids)

    return m
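
# Worked example (hypothetical numbers): an attack configured with
# `batch_size=32` on a host with 4 visible GPUs ends up with an effective
# batch of 128, i.e. the original 32 samples per device.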


def unwrap_attack_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """
    Undo `wrap_attack_multi_gpu`, moving the attack's stored Pipeline and
    Perturbation instances back to a single `device`.
    """
    if hasattr(m, 'pipeline') and isinstance(m.pipeline, DataParallelWrapper):
        m.pipeline = unwrap_module_multi_gpu(m.pipeline, device)

    if hasattr(m, 'perturbation') and isinstance(m.perturbation, DataParallelWrapper):
        m.perturbation = unwrap_module_multi_gpu(m.perturbation, device)

    # undo the per-device batch size scaling applied by `wrap_attack_multi_gpu`
    num_devices = len(get_cuda_device_ids())
    if hasattr(m, 'batch_size') and num_devices > 0:
        m.batch_size = m.batch_size // num_devices

    return m


def wrap_pipeline_multi_gpu(m: nn.Module, device_ids: List[int]):
    """
    Implement data parallelism for Pipeline objects, including all intermediate
    stages that may be accessed outside of `forward()`.
    """
    if len(device_ids) < 1:
        return m

    return wrap_module_multi_gpu(m, device_ids)


def unwrap_pipeline_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """Undo `wrap_pipeline_multi_gpu`, moving the bare Pipeline to `device`."""
    return unwrap_module_multi_gpu(m, device)
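

if __name__ == "__main__":
    # Minimal round-trip demo (a sketch, not part of the library API): wrap a
    # toy module across all visible GPUs, confirm checkpoint keys carry no
    # `module.` prefix, then unwrap back onto a single device. Runs on
    # CPU-only hosts as well, where wrapping is a no-op.
    device_ids = get_cuda_device_ids()
    model = wrap_module_multi_gpu(nn.Linear(4, 2), device_ids)

    # keys are 'weight' and 'bias' whether or not the model was wrapped
    print(list(model.state_dict().keys()))

    target = 'cuda:0' if device_ids else 'cpu'
    model = unwrap_module_multi_gpu(model, target)
    print(type(model).__name__, next(model.parameters()).device)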