import torch


def _get_new_non_contiguous_tensor_shape(shape):
    """
    Get the expanded shape used to back a non-contiguous tensor of ``shape``.
    The last dimension is increased by 128 (for alignment) and every other
    dimension is increased by 1; for example, (4, 8) expands to [5, 136].
    """
    return [dim + 128 if dim_idx == len(shape) - 1 else dim + 1 for dim_idx, dim in enumerate(shape)]


def gen_non_contiguous_randn_tensor(shape, *args, **kwargs):
    """Return a non-contiguous tensor of ``shape`` filled with random normal values."""
    new_shape = _get_new_non_contiguous_tensor_shape(shape)
    base_tensor = torch.randn(new_shape, *args, **kwargs)
    # Slice the oversized base tensor back down to ``shape``; indexing with a
    # tuple of slices returns a non-contiguous view of the base storage.
    slices = tuple(slice(0, dim) for dim in shape)
    return base_tensor[slices]


def gen_non_contiguous_tensor(shape, *args, **kwargs):
    """Return an uninitialized non-contiguous tensor of ``shape``."""
    new_shape = _get_new_non_contiguous_tensor_shape(shape)
    base_tensor = torch.empty(new_shape, *args, **kwargs)
    slices = tuple(slice(0, dim) for dim in shape)
    return base_tensor[slices]


def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor:
    """Return a non-contiguous copy of ``tensor`` with the same values, dtype and device."""
    new_tensor = gen_non_contiguous_tensor(tensor.shape, dtype=tensor.dtype, device=tensor.device)
    new_tensor[:] = tensor
    return new_tensor
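

# Minimal usage sketch (illustrative; the ``__main__`` guard is an assumption,
# not part of the original module): slicing the padded base tensor back to the
# requested shape keeps the shape but leaves gaps in the strides, so the result
# reports is_contiguous() == False while holding the expected values.
if __name__ == "__main__":
    t = gen_non_contiguous_randn_tensor((4, 8))
    assert t.shape == (4, 8)
    assert not t.is_contiguous()

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    y = non_contiguousify(x)
    assert torch.equal(x, y)
    assert not y.is_contiguous()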