# flash-mla / tests / kernelkit / generate.py
# Uploaded via huggingface_hub by medmekk (commit ccef021, verified)
import torch
def _get_new_non_contiguous_tensor_shape(shape):
"""
Get the expanded shape for a non-contiguous tensor.
The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1
"""
return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)]
def gen_non_contiguous_randn_tensor(shape, *args, **kwargs):
    """
    Return a standard-normal tensor of `shape` that is guaranteed non-contiguous.

    A larger buffer (see `_get_new_non_contiguous_tensor_shape`) is allocated
    with `torch.randn` and the requested `shape` is sliced back out, so the
    result is a strided view over the padded storage.

    Extra `*args`/`**kwargs` (e.g. dtype, device) are forwarded to `torch.randn`.
    """
    new_shape = _get_new_non_contiguous_tensor_shape(shape)
    base_tensor = torch.randn(new_shape, *args, **kwargs)
    # Index with a tuple of slices: indexing with a *list* of slices is
    # deprecated NumPy-style advanced indexing and warns/errors on newer torch.
    slices = tuple(slice(0, dim) for dim in shape)
    return base_tensor[slices]
def gen_non_contiguous_tensor(shape, *args, **kwargs):
    """
    Return an uninitialized tensor of `shape` that is guaranteed non-contiguous.

    A larger buffer (see `_get_new_non_contiguous_tensor_shape`) is allocated
    with `torch.empty` and the requested `shape` is sliced back out, so the
    result is a strided view over the padded storage. Values are uninitialized.

    Extra `*args`/`**kwargs` (e.g. dtype, device) are forwarded to `torch.empty`.
    """
    new_shape = _get_new_non_contiguous_tensor_shape(shape)
    base_tensor = torch.empty(new_shape, *args, **kwargs)
    # Index with a tuple of slices: indexing with a *list* of slices is
    # deprecated NumPy-style advanced indexing and warns/errors on newer torch.
    slices = tuple(slice(0, dim) for dim in shape)
    return base_tensor[slices]
def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a non-contiguous copy of `tensor`.

    The copy has the same shape, dtype, device and values as the input, but is
    backed by an over-allocated buffer so it is not contiguous in memory.
    """
    dst = gen_non_contiguous_tensor(
        tensor.shape,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    dst[:] = tensor
    return dst