Spaces:
Sleeping
Sleeping
File size: 3,711 Bytes
957e2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch
import torch.nn as nn
from typing import OrderedDict, Dict, Any, TypeVar, Union
################################################################################
# Utilities for single/multi-GPU training
################################################################################
class DataParallelWrapper(nn.DataParallel):
    """Extend ``nn.DataParallel`` to behave as a drop-in replacement for the
    wrapped module.

    Three transparency fixes over plain ``nn.DataParallel``:
      * attribute access falls through to the wrapped module, so methods and
        attributes defined on the module stay reachable on the wrapper;
      * ``state_dict`` delegates to the wrapped module, so saved keys carry
        no ``module.`` prefix (checkpoints stay loadable without a wrapper);
      * ``load_state_dict`` likewise accepts un-prefixed keys.
    """

    def __getattr__(self, name):
        # Look on the wrapper first (parameters/buffers/submodules that
        # nn.Module bookkeeping owns), then fall through to the wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

    def state_dict(self,
                   *args,
                   destination=None,
                   prefix='',
                   keep_vars=False):
        """Avoid `module` prefix in saved weights.

        Delegates to the wrapped module so keys match a non-parallel model.
        """
        return self.module.state_dict(
            destination=destination,
            prefix=prefix,
            keep_vars=keep_vars
        )

    def load_state_dict(self,
                        state_dict: OrderedDict[str, torch.Tensor],
                        strict: bool = True):
        """Avoid `module` prefix in saved weights.

        Returns the ``NamedTuple`` of ``missing_keys``/``unexpected_keys``
        produced by ``nn.Module.load_state_dict`` (previously discarded),
        so callers can inspect partial loads.
        """
        return self.module.load_state_dict(state_dict, strict=strict)
def get_cuda_device_ids():
    """Return the ids of every visible CUDA device (empty list on CPU-only)."""
    return [*range(torch.cuda.device_count())]
def wrap_module_multi_gpu(m: nn.Module, device_ids: list):
    """Implement data parallelism for arbitrary Module objects.

    No-op when no devices are given or when `m` is already wrapped;
    otherwise returns `m` wrapped in a DataParallelWrapper.
    """
    if len(device_ids) < 1:
        return m
    if isinstance(m, DataParallelWrapper):
        return m
    return DataParallelWrapper(module=m, device_ids=device_ids)
def unwrap_module_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """Strip an optional DataParallelWrapper from `m` and move it to `device`."""
    target = m.module if isinstance(m, DataParallelWrapper) else m
    return target.to(device)
def wrap_attack_multi_gpu(m: nn.Module, device_ids: list):
    """
    Implement data parallelism for attack objects, including stored Pipeline
    and Perturbation instances that may be accessed outside of `forward()`
    """
    if len(device_ids) < 1:
        return m
    pipeline = getattr(m, 'pipeline', None)
    if isinstance(pipeline, nn.Module):
        m.pipeline = wrap_pipeline_multi_gpu(pipeline, device_ids)
    perturbation = getattr(m, 'perturbation', None)
    if isinstance(perturbation, nn.Module):
        m.perturbation = wrap_module_multi_gpu(perturbation, device_ids)
    # scale batch size to number of devices (keeps per-device batch constant)
    if hasattr(m, 'batch_size'):
        m.batch_size = m.batch_size * len(device_ids)
    return m
def unwrap_attack_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """
    Undo `wrap_attack_multi_gpu`: strip DataParallelWrapper from the stored
    `pipeline`/`perturbation` submodules, move them to `device`, and restore
    the original (pre-scaling) `batch_size`.

    Fixes vs. the previous version: the batch-size rescale now uses the
    device count actually recorded on the wrappers being removed — not the
    machine's total CUDA device count — and is skipped entirely when the
    attack was never wrapped (which previously raised ZeroDivisionError on
    CPU-only hosts and silently shrank the batch size).
    """
    n_devices = 0
    if hasattr(m, 'pipeline') and isinstance(m.pipeline, DataParallelWrapper):
        # nn.DataParallel records its device_ids; reuse them for the rescale
        n_devices = max(n_devices, len(m.pipeline.device_ids))
        m.pipeline = unwrap_module_multi_gpu(m.pipeline, device)
    if hasattr(m, 'perturbation') and isinstance(m.perturbation, DataParallelWrapper):
        n_devices = max(n_devices, len(m.perturbation.device_ids))
        m.perturbation = unwrap_module_multi_gpu(m.perturbation, device)
    # undo the batch-size scaling only if something was actually unwrapped;
    # NOTE(review): an attack with a scaled batch_size but no wrapped
    # submodules cannot be detected here and is left untouched
    if n_devices > 1 and hasattr(m, 'batch_size'):
        m.batch_size = m.batch_size // n_devices
    return m
def wrap_pipeline_multi_gpu(m: nn.Module, device_ids: list):
    """
    Implement data parallelism for Pipeline objects, including all intermediate
    stages that may be accessed outside of `forward()`
    """
    return m if len(device_ids) < 1 else wrap_module_multi_gpu(m, device_ids)
def unwrap_pipeline_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """Undo `wrap_pipeline_multi_gpu`: strip any wrapper and move to `device`."""
    return unwrap_module_multi_gpu(m, device)
|