Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import OrderedDict, Dict, Any, TypeVar, Union | |
| ################################################################################ | |
| # Utilities for single/multi-GPU training | |
| ################################################################################ | |
| class DataParallelWrapper(nn.DataParallel): | |
| """Extend DataParallel class to allow full method/attribute access""" | |
| def __getattr__(self, name): | |
| try: | |
| return super().__getattr__(name) | |
| except AttributeError: | |
| return getattr(self.module, name) | |
| def state_dict(self, | |
| *args, | |
| destination=None, | |
| prefix='', | |
| keep_vars=False): | |
| """Avoid `module` prefix in saved weights""" | |
| return self.module.state_dict( | |
| destination=destination, | |
| prefix=prefix, | |
| keep_vars=keep_vars | |
| ) | |
| def load_state_dict(self, | |
| state_dict: OrderedDict[str, torch.Tensor], | |
| strict: bool = True): | |
| """Avoid `module` prefix in saved weights""" | |
| self.module.load_state_dict(state_dict, strict) | |
| def get_cuda_device_ids(): | |
| """Fetch all available CUDA devices""" | |
| return list(range(torch.cuda.device_count())) | |
| def wrap_module_multi_gpu(m: nn.Module, device_ids: list): | |
| """Implement data parallelism for arbitrary Module objects.""" | |
| if len(device_ids) < 1: | |
| return m | |
| elif isinstance(m, DataParallelWrapper): | |
| return m | |
| else: | |
| return DataParallelWrapper( | |
| module=m, | |
| device_ids=device_ids | |
| ) | |
| def unwrap_module_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]): | |
| if isinstance(m, DataParallelWrapper): | |
| return m.module.to(device) | |
| else: | |
| return m.to(device) | |
| def wrap_attack_multi_gpu(m: nn.Module, device_ids: list): | |
| """ | |
| Implement data parallelism for attack objects, including stored Pipeline | |
| and Perturbation instances that may be accessed outside of `forward()` | |
| """ | |
| if len(device_ids) < 1: | |
| return m | |
| if hasattr(m, 'pipeline') and isinstance(m.pipeline, nn.Module): | |
| m.pipeline = wrap_pipeline_multi_gpu(m.pipeline, device_ids) | |
| if hasattr(m, 'perturbation') and isinstance(m.perturbation, nn.Module): | |
| m.perturbation = wrap_module_multi_gpu(m.perturbation, device_ids) | |
| # scale batch size to number of devices | |
| if hasattr(m, 'batch_size'): | |
| m.batch_size *= len(device_ids) | |
| return m | |
| def unwrap_attack_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]): | |
| """ | |
| """ | |
| if hasattr(m, 'pipeline') and isinstance(m.pipeline, DataParallelWrapper): | |
| m.pipeline = unwrap_module_multi_gpu(m.pipeline, device) | |
| if hasattr(m, 'perturbation') and isinstance(m.perturbation, DataParallelWrapper): | |
| m.perturbation = unwrap_module_multi_gpu(m.perturbation, device) | |
| # scale batch size to number of devices | |
| if hasattr(m, 'batch_size'): | |
| m.batch_size = m.batch_size // len(get_cuda_device_ids()) | |
| return m | |
| def wrap_pipeline_multi_gpu(m: nn.Module, device_ids: list): | |
| """ | |
| Implement data parallelism for Pipeline objects, including all intermediate | |
| stages that may be accessed outside of `forward()` | |
| """ | |
| if len(device_ids) < 1: | |
| return m | |
| return wrap_module_multi_gpu(m, device_ids) | |
| def unwrap_pipeline_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]): | |
| """ | |
| """ | |
| return unwrap_module_multi_gpu(m, device) | |