import torch def _get_new_non_contiguous_tensor_shape(shape): """ Get the expanded shape for a non-contiguous tensor. The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1 """ return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)] def gen_non_contiguous_randn_tensor(shape, *args, **kwargs): new_shape = _get_new_non_contiguous_tensor_shape(shape) base_tensor = torch.randn(new_shape, *args, **kwargs) slices = [slice(0, dim) for dim in shape] return base_tensor[slices] def gen_non_contiguous_tensor(shape, *args, **kwargs): new_shape = _get_new_non_contiguous_tensor_shape(shape) base_tensor = torch.empty(new_shape, *args, **kwargs) slices = [slice(0, dim) for dim in shape] return base_tensor[slices] def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor: new_tensor = gen_non_contiguous_tensor(tensor.shape, dtype=tensor.dtype, device=tensor.device) new_tensor[:] = tensor return new_tensor