# Voiceblock demo: Attempt 8 (ALeLacheur, commit 957e2dc)
import torch
import torch.nn as nn
from typing import OrderedDict, Dict, Any, TypeVar, Union
################################################################################
# Utilities for single/multi-GPU training
################################################################################
class DataParallelWrapper(nn.DataParallel):
    """Extend DataParallel class to allow full method/attribute access.

    `nn.DataParallel` only exposes the wrapped module via `.module`; this
    subclass forwards any attribute/method lookup that fails on the wrapper
    to the underlying module, and saves/loads weights without the `module.`
    prefix so checkpoints stay interchangeable with the unwrapped model.
    """

    def __getattr__(self, name):
        # Try the normal nn.Module lookup first (parameters, buffers,
        # submodules, regular attributes); fall back to the wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

    def state_dict(self,
                   *args,
                   destination=None,
                   prefix='',
                   keep_vars=False):
        """Avoid `module` prefix in saved weights.

        Delegates to the wrapped module's `state_dict`. BUG FIX: forward
        positional `*args` too — previously they were accepted but dropped.
        """
        return self.module.state_dict(
            *args,
            destination=destination,
            prefix=prefix,
            keep_vars=keep_vars
        )

    def load_state_dict(self,
                        state_dict: OrderedDict[str, torch.Tensor],
                        strict: bool = True):
        """Avoid `module` prefix in saved weights.

        BUG FIX: return the result of the wrapped module's
        `load_state_dict` (its missing/unexpected-keys NamedTuple),
        which was previously discarded.
        """
        return self.module.load_state_dict(state_dict, strict)
def get_cuda_device_ids():
    """Return the indices of all visible CUDA devices as a list (empty on CPU-only hosts)."""
    return [device_index for device_index in range(torch.cuda.device_count())]
def wrap_module_multi_gpu(m: nn.Module, device_ids: list):
    """Implement data parallelism for arbitrary Module objects.

    An empty device list or an already-wrapped module is returned
    unchanged; otherwise the module is wrapped in DataParallelWrapper.
    """
    if not device_ids:
        return m
    if isinstance(m, DataParallelWrapper):
        return m
    return DataParallelWrapper(module=m, device_ids=device_ids)
def unwrap_module_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """Undo data parallelism: extract the wrapped module (if any) and move it to `device`."""
    target = m.module if isinstance(m, DataParallelWrapper) else m
    return target.to(device)
def wrap_attack_multi_gpu(m: nn.Module, device_ids: list):
    """
    Implement data parallelism for attack objects, including stored Pipeline
    and Perturbation instances that may be accessed outside of `forward()`
    """
    # Nothing to do when no devices were requested.
    if not device_ids:
        return m

    # Wrap the stored pipeline / perturbation in place, when present.
    pipeline = getattr(m, 'pipeline', None)
    if isinstance(pipeline, nn.Module):
        m.pipeline = wrap_pipeline_multi_gpu(pipeline, device_ids)
    perturbation = getattr(m, 'perturbation', None)
    if isinstance(perturbation, nn.Module):
        m.perturbation = wrap_module_multi_gpu(perturbation, device_ids)

    # scale batch size to number of devices
    if hasattr(m, 'batch_size'):
        m.batch_size = m.batch_size * len(device_ids)
    return m
def unwrap_attack_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """
    Undo `wrap_attack_multi_gpu`: unwrap the stored Pipeline and Perturbation
    instances (when wrapped) onto `device` and rescale the batch size back
    down by the number of visible CUDA devices.
    """
    if hasattr(m, 'pipeline') and isinstance(m.pipeline, DataParallelWrapper):
        m.pipeline = unwrap_module_multi_gpu(m.pipeline, device)
    if hasattr(m, 'perturbation') and isinstance(m.perturbation, DataParallelWrapper):
        m.perturbation = unwrap_module_multi_gpu(m.perturbation, device)
    # scale batch size to number of devices.
    # BUG FIX: on a CPU-only host torch.cuda.device_count() is 0, and the
    # previous `// len(get_cuda_device_ids())` raised ZeroDivisionError;
    # clamp the divisor to at least 1.
    # NOTE(review): wrap_attack_multi_gpu scales by len(device_ids), but this
    # divides by the *global* device count — these can differ if a subset of
    # devices was used; confirm against callers.
    if hasattr(m, 'batch_size'):
        m.batch_size = m.batch_size // max(1, torch.cuda.device_count())
    return m
def wrap_pipeline_multi_gpu(m: nn.Module, device_ids: list):
    """
    Implement data parallelism for Pipeline objects, including all intermediate
    stages that may be accessed outside of `forward()`
    """
    # With no devices requested the pipeline is passed through untouched;
    # otherwise delegate to the generic module wrapper.
    return m if not device_ids else wrap_module_multi_gpu(m, device_ids)
def unwrap_pipeline_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
    """Undo data parallelism for a Pipeline by delegating to the generic module unwrapper."""
    unwrapped = unwrap_module_multi_gpu(m, device)
    return unwrapped