| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| |
|
| |
|
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Raises:
        AssertionError: if ``numerator % denominator`` is non-zero. The
            check is an explicit ``raise`` instead of an ``assert``
            statement so that it is not stripped under ``python -O``; the
            exception type stays AssertionError for existing callers.
        ZeroDivisionError: if ``denominator`` is zero (for int operands,
            raised by the ``%`` operation itself).
    """
    if numerator % denominator != 0:
        raise AssertionError('{} is not divisible by {}'.format(
            numerator, denominator))
| |
|
| |
|
def divide(numerator, denominator):
    """Return the exact integer quotient ``numerator // denominator``.

    The pair is first validated by ``ensure_divisibility``; only after that
    check is the floor division taken and returned.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
| |
|
| |
|
def split_tensor_along_last_dim(tensor, num_partitions,
                                contiguous_split_chunks=False):
    """Split ``tensor`` into ``num_partitions`` equal chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into. The
            size of the last dimension must be divisible by it; ``divide``
            performs that check.
        contiguous_split_chunks: if True, every returned chunk is passed
            through ``.contiguous()`` before being returned.

    Returns:
        The chunks produced by ``torch.split``, or, when
        ``contiguous_split_chunks`` is True, a tuple of their
        ``.contiguous()`` results.
    """
    split_dim = tensor.dim() - 1
    chunk_size = divide(tensor.size()[split_dim], num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=split_dim)

    # Guard clause: hand back torch.split's result untouched unless the
    # caller explicitly asked for contiguous chunks.
    if not contiguous_split_chunks:
        return chunks
    return tuple(chunk.contiguous() for chunk in chunks)
| |
|
| |
|
class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that the indices form the half-open range
    [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                  rank, world_size):
        """Return ``(first, last)`` vocabulary indices owned by ``rank``.

        ``first = rank * per_partition_vocab_size`` and
        ``last = first + per_partition_vocab_size``; the range is half-open.
        ``world_size`` is not used in this computation (presumably kept for
        symmetry with ``vocab_range_from_global_vocab_size``).
        """
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return ``(first, last)`` vocabulary indices owned by ``rank``.

        The global vocabulary of ``global_vocab_size`` entries is divided
        into ``world_size`` equal partitions; ``divide`` checks that the
        global size is a multiple of ``world_size``.
        """
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size)
| |
|