| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| |
|
| | import contextlib |
| |
|
| | import torch |
| | from torch import _C |
| | from torch.cuda import _lazy_call, device as device_ctx_manager |
| | from torch.utils.checkpoint import detach_variable |
| |
|
| | from megatron.memory import allocate_mem_buff |
| |
|
| | from .initialize import get_data_parallel_rank |
| | from .initialize import get_tensor_model_parallel_group |
| | from .initialize import get_tensor_model_parallel_rank |
| | from .initialize import get_tensor_model_parallel_world_size |
| |
|
| |
|
| | |
# Name under which the tensor-model-parallel CUDA rng state is registered in
# the tracker; it is also the default `name` used by
# `CudaRNGStatesTracker.fork`.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
| |
|
| |
|
| | def _set_cuda_rng_state(new_state, device=-1): |
| | """Sets the random number generator state of the current GPU. |
| | |
| | Argumentss: |
| | new_state (torch.ByteTensor): The desired state |
| | This function is adapted from PyTorch repo (torch.cuda.set_rng_state) |
| | with a single change: the input state is not cloned. Cloning caused |
| | major performance issues for +4 GPU cases. |
| | """ |
| | if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): |
| | |
| | def cb(): |
| | with device_ctx_manager(device): |
| | _C._cuda_setRNGState(new_state) |
| | else: |
| | |
| | if device == -1: |
| | device = torch.device('cuda') |
| | elif isinstance(device, str): |
| | device = torch.device(device) |
| | elif isinstance(device, int): |
| | device = torch.device('cuda', device) |
| |
|
| | def cb(): |
| | idx = device.index |
| | if idx is None: |
| | idx = torch.cuda.current_device() |
| | default_generator = torch.cuda.default_generators[idx] |
| | default_generator.set_state(new_state) |
| |
|
| | _lazy_call(cb) |
| |
|
| |
|
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """Return this rank's equal-sized 1D chunk of `tensor`.

    The tensor is flattened with `view(-1)` and cut into
    tensor-model-parallel-world-size pieces of `numel // world_size`
    elements each; the piece indexed by the tensor-model-parallel rank is
    returned.

    Arguments:
        tensor: Tensor to split.
        new_buffer: When True, the chunk is copied into a freshly allocated
            tensor on the current CUDA device; otherwise a slice of the
            flattened input is returned.
    """
    chunk_numel = torch.numel(tensor) // get_tensor_model_parallel_world_size()
    first = chunk_numel * get_tensor_model_parallel_rank()
    last = first + chunk_numel

    if not new_buffer:
        return tensor.view(-1)[first:last]

    chunk = torch.empty(
        chunk_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    chunk.copy_(tensor.view(-1)[first:last])
    return chunk
| | |
| |
|
def gather_split_1d_tensor(tensor):
    """Gather the per-rank 1D chunks from all tensor-model-parallel ranks.

    Opposite of `split_tensor_into_1d_equal_chunks`: every rank contributes
    `tensor`, and the result is one flat tensor holding
    `numel(tensor) * world_size` elements.

    Arguments:
        tensor: This rank's chunk.

    Returns:
        A new flat tensor (same dtype, on the current CUDA device, not
        requiring grad) filled by the all-gather.
    """
    numel_gathered = torch.numel(tensor) * \
        get_tensor_model_parallel_world_size()
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    # `_all_gather_base` is a private API that newer PyTorch releases
    # deprecate in favour of the public `all_gather_into_tensor`, which takes
    # the same (output, input, group) arguments. Prefer the public function
    # and fall back to the private one on older PyTorch versions.
    all_gather = getattr(torch.distributed, 'all_gather_into_tensor', None)
    if all_gather is None:
        all_gather = torch.distributed._all_gather_base
    all_gather(gathered, tensor, group=get_tensor_model_parallel_group())
    return gathered
| |
|
| |
|
| | def _kernel_make_viewless_tensor(inp, requires_grad): |
| | '''Make a viewless tensor. |
| | |
| | View tensors have the undesirable side-affect of retaining a reference |
| | to the originally-viewed tensor, even after manually setting the '.data' |
| | field. This method creates a new tensor that links to the old tensor's |
| | data, without linking the viewed tensor, referenced via the '._base' |
| | field. |
| | ''' |
| | out = torch.empty( |
| | (1,), |
| | dtype = inp.dtype, |
| | device = inp.device, |
| | requires_grad = requires_grad, |
| | ) |
| | out.data = inp.data |
| | return out |
| |
|
class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd-aware way of producing a viewless tensor.

    Use this when the computation graph has to be kept while the tensor
    itself must not be a view (e.g., ParallelTransformer's hidden_states).
    It is reached by passing 'keep_graph = True' to
    'make_viewless_tensor()'.
    '''

    @staticmethod
    def forward(ctx, inp, requires_grad):
        '''Return a viewless tensor built from inp.'''
        viewless = _kernel_make_viewless_tensor(inp, requires_grad)
        return viewless

    @staticmethod
    def backward(ctx, grad_output):
        '''Pass the gradient through for inp; none for requires_grad.'''
        return (grad_output, None)
| |
|
def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Single entry point for obtaining a viewless tensor.

    Call this instead of using 'MakeViewlessTensor' or
    '_kernel_make_viewless_tensor' directly; it picks between the autograd
    function and the plain helper.

    Arguments:
        inp: Tensor to make viewless.
        requires_grad: 'requires_grad' flag for the tensor that is created.
        keep_graph: If true, go through the autograd function so the
            computation graph is propagated; otherwise use the plain helper.
    '''
    # Nothing to do when the tensor is not a view.
    if inp._base is None:
        return inp

    # The tensor is a view: build a viewless replacement.
    if not keep_graph:
        return _kernel_make_viewless_tensor(inp, requires_grad)
    return MakeViewlessTensor.apply(inp, requires_grad)
| |
|
def assert_viewless_tensor(tensor, extra_msg = None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).

    Lists are checked element by element (recursively). Anything that is
    neither a list nor a torch.Tensor is returned without being checked.

    Arguments:
        tensor: A tensor, a list of objects to check, or any other object.
        extra_msg: Optional context appended to the assertion message.

    Returns:
        The input object, unchanged.

    Raises:
        AssertionError: If a checked tensor has a '._base' that is not None.
    '''
    if isinstance(tensor, list):
        # A plain loop instead of a throwaway list comprehension; extra_msg
        # is forwarded so a failure inside a list keeps the caller's context.
        for item in tensor:
            assert_viewless_tensor(item, extra_msg = extra_msg)
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # Leave the suffix empty rather than printing a literal "None", and
    # format through a 1-tuple so a tuple-valued extra_msg cannot break the
    # '%' formatting with a TypeError.
    detail = "" if extra_msg is None else extra_msg
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % (detail,)
    return tensor
| |
|
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Replace tensor's '.data' field, but only if tensor is viewless.

    The tensor is first checked with 'assert_viewless_tensor' (i.e., its
    '._base' must not be set); if that check fails an exception is raised
    and '.data' is left untouched.
    '''
    base = tensor._base
    base_shape = "--" if base is None else base.shape
    detail = (
        "FYI, tensor._base has shape %s, "
        "and new_data_tensor has shape %s."
    ) % (base_shape, new_data_tensor.shape)
    assert_viewless_tensor(tensor, extra_msg = detail)
    tensor.data = new_data_tensor
| |
|
| |
|
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from tracker name to the stored cuda rng state.
        self.states_ = {}
        # Seeds already registered; used only to reject duplicates.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        return dict(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state.

        Seeds the cuda rng with `seed`, stores the resulting state under
        `name`, and then restores the rng state that was current before
        the call.

        Raises:
            Exception: if `seed` or `name` has already been registered.
        """
        # Validate both arguments before mutating anything, so that a
        # rejected `add` (duplicate name) does not leave its seed recorded.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        self.seeds_.add(seed)
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state.

        Raises:
            Exception: if `name` has not been added to the tracker.
        """
        # Check if we have added the state.
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store the current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set the rng state to the tracked one.
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the tracked rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state back to what it was at the start.
            _set_cuda_rng_state(orig_cuda_rng_state)
| |
|
| |
|
| | |
# Module-level tracker instance; returned by `get_cuda_rng_tracker()` and
# reset/populated by `model_parallel_cuda_manual_seed()`.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
| |
|
| |
|
def get_cuda_rng_tracker():
    """Return the module-level cuda rng state tracker instance."""
    tracker = _CUDA_RNG_STATE_TRACKER
    return tracker
| |
|
| |
|
def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    Call this after model parallelism has been initialized, and do not
    call torch.cuda.manual_seed afterwards: this function is meant to
    replace it.

    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used for
                       example for dropout in the non-tensor-model-parallel
                       regions.
        tensor-model-parallel state: This state is different among a set of
                       model parallel GPUs, but the same across data
                       parallel groups. This is used for example for dropout
                       in model parallel regions.
    """
    # The default (data parallel) state is seeded with `seed` itself.
    data_parallel_seed = seed
    # The tensor-model-parallel seed is `seed` plus a fixed offset plus the
    # tensor-model-parallel rank, so it differs from rank to rank.
    tensor_model_parallel_seed = (seed + 2718) + get_tensor_model_parallel_rank()

    if torch.distributed.get_rank() == 0:
        banner = ('> initializing model parallel cuda seeds on global rank {}, '
                  'model parallel rank {}, and data parallel rank {} with '
                  'model parallel seed: {} and data parallel seed: {}')
        print(
            banner.format(
                torch.distributed.get_rank(),
                get_tensor_model_parallel_rank(),
                get_data_parallel_rank(),
                tensor_model_parallel_seed,
                data_parallel_seed,
            ),
            flush=True,
        )

    tracker = _CUDA_RNG_STATE_TRACKER
    tracker.reset()
    # Seed the default cuda rng state.
    torch.cuda.manual_seed(data_parallel_seed)
    # Register the tensor-model-parallel state with the tracker.
    tracker.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
| |
|
| |
|
class CheckpointFunction(torch.autograd.Function):
    """Activation-checkpointing autograd function.

    This function is adapted from torch.utils.checkpoint with
    two main changes:
        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
        2) the states in the model parallel tracker are also properly
           tracked/set/reset.

    The forward pass runs `run_function` under `torch.no_grad()` and saves
    only the inputs; the backward pass re-runs `run_function` with grad
    enabled, under the rng states captured in forward, and backpropagates
    through the recomputed outputs.
    """
    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        """Run `run_function(*args)` without building a graph.

        When `distribute_saved_activations` is true, `args[0]`'s data is
        replaced in place by this rank's 1D chunk before the inputs are
        saved for backward.
        """
        ctx.run_function = run_function
        ctx.distribute_saved_activations \
            = distribute_saved_activations

        # Copy the rng states (CPU, CUDA, and the tracker's named states) so
        # backward can recompute under the same randomness.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Split args[0] across the tensor model parallel group and keep only
        # the chunk for the current rank (copied into a new buffer). The
        # original shape is remembered so backward can restore it.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass with grad enabled and backpropagate.

        Returns `None` for `run_function` and `distribute_saved_activations`,
        followed by one entry per saved input (its `.grad` for tensors).
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Undo the forward-time split: gather the chunks from all ranks and
        # view the result with the shape recorded in forward.
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Recompute the forward pass on detached inputs, this time recording
        # a graph.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # Backpropagate the incoming gradients through the recomputed graph;
        # this populates `.grad` on the detached inputs.
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        # The two leading Nones correspond to `run_function` and
        # `distribute_saved_activations`.
        return (None, None) + grads
| |
|
| |
|
def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.

    Thin wrapper that forwards to `CheckpointFunction.apply`.

    Arguments:
        function: Callable to run; it is invoked with `*args`.
        distribute_saved_activations: Forwarded unchanged to
            `CheckpointFunction`.
        *args: Inputs passed to `function`.
    """
    forwarded = (function, distribute_saved_activations) + args
    return CheckpointFunction.apply(*forwarded)
| |
|