| from typing import List, Tuple | |
| from torch.distributed._shard.sharding_spec import ( | |
| ShardMetadata, | |
| ) | |
| def _shards_get_overlap_region_wrt_saved_tensor( | |
| saved_shard: ShardMetadata, current_shard: ShardMetadata | |
| ) -> List[Tuple[int, int, int, int]]: | |
| """ | |
| Return the overlapping region between saved_shard and current_shard. | |
| There returned list has the same number of elements as the tensor's dimension. | |
| For each element, we produce a tuple with the following contents: | |
| (dimension, `saved_shard` offset, `current_shard` offset, length) | |
| Offsets are relative to each shard. | |
| """ | |
| narrows = [] | |
| for dim, ( | |
| saved_shard_offset, | |
| current_shard_offset, | |
| saved_shard_size, | |
| current_shard_size, | |
| ) in enumerate( | |
| zip( | |
| saved_shard.shard_offsets, | |
| current_shard.shard_offsets, | |
| saved_shard.shard_sizes, | |
| current_shard.shard_sizes, | |
| ) | |
| ): | |
| min_range_end = min( | |
| saved_shard_offset + saved_shard_size, | |
| current_shard_offset + current_shard_size, | |
| ) | |
| length = min_range_end - max(current_shard_offset, saved_shard_offset) | |
| if saved_shard_offset > current_shard_offset: | |
| offset_for_saved_tensor = 0 | |
| offset_for_current_tensor = saved_shard_offset - current_shard_offset | |
| else: | |
| offset_for_saved_tensor = current_shard_offset - saved_shard_offset | |
| offset_for_current_tensor = 0 | |
| narrows.append( | |
| (dim, offset_for_saved_tensor, offset_for_current_tensor, length) | |
| ) | |
| return narrows | |