import torch

_num_sms = None


def set_num_sms(num_sms: int) -> None:
    """
    Set the maximum SM count for all GEMM kernels to use.

    Arguments:
        num_sms: the desired maximum SM count for all GEMM kernels to use.
    """
    global _num_sms
    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
    _num_sms = num_sms


def get_num_sms() -> int:
    """
    Get the current maximum SM count limit for all GEMM kernels to use.
    If the count has never been specified, the function returns the number of SMs on the device.

    Returns:
        Current maximum SM count limit for all GEMM kernels to use.
    """
    global _num_sms
    if _num_sms is None:
        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
    return _num_sms


def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Arguments:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    return (x + y - 1) // y
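# For example, ceil_div(129, 128) == 2 while ceil_div(256, 128) == 2.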


def get_m_alignment_for_contiguous_layout():
    """
    When we do a grouped GEMM in contiguous format, the LHS is grouped into several batches along the M axis.
    Since each GEMM block deals with exactly one sub-matrix of the RHS, the sizes of those batches should be
    aligned to the GEMM block shape.

    Returns:
        Group-level alignment requirement for the grouped contiguous layout, which is always 128.
    """
    return 128
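# For example, a group whose LHS has 300 rows would be padded up to
# ceil_div(300, 128) * 128 == 384 rows, so that each group starts on a block boundary.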


def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory addresses used by TMA must be 16-byte aligned.
    Since we use a column-major layout for the LHS scaling tensor,
    the M axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis size of the LHS scaling tensor.
        element_size: element size (in bytes) of the LHS scaling tensor.

    Returns:
        M-axis size of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment
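# For example, float32 scaling factors (element_size == 4) must be aligned to 16 // 4 == 4
# elements, so get_tma_aligned_size(130, 4) == 132.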


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Return a TMA-aligned, transposed (column-major) version of the input tensor. `torch.transpose` will be called if necessary.
    If the input tensor is already in column-major layout and 16-byte aligned along the M axis
    (and thus meets the requirement for the LHS scaling tensor in DeepGEMM), this function returns the input unchanged.

    Arguments:
        x: usually the LHS scaling tensor in GEMM.

    Returns:
        The LHS scaling tensor in TMA-aligned, transposed format.
    """
    assert x.dim() in (2, 3)
    remove_dim = False
    if x.dim() == 2:
        x, remove_dim = x.unsqueeze(0), True

    b, m, n = x.shape
    aligned_m = get_tma_aligned_size(m, x.element_size())

    # Already column-major with the M axis padded to the TMA-aligned size: nothing to do
    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
        return x.squeeze(0) if remove_dim else x

    # Otherwise, allocate a column-major buffer with the padded M axis and copy the data into it
    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
    aligned_x[:, :m, :] = x
    aligned_x = aligned_x[:, :m, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
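

# A minimal, illustrative self-check (not part of the module API), assuming a float32
# LHS scaling tensor; it only sketches how the helpers above compose.
if __name__ == '__main__':
    scales = torch.randn(129, 7, dtype=torch.float32)
    aligned = get_col_major_tma_aligned_tensor(scales)
    # Same logical shape and values, but column-major with M padded to the TMA-aligned size
    assert aligned.shape == scales.shape
    assert aligned.stride(0) == 1
    assert aligned.stride(1) == get_tma_aligned_size(scales.shape[0], scales.element_size())
    assert torch.equal(aligned, scales)
    if torch.cuda.is_available():
        print(f'Maximum SM count for GEMM kernels: {get_num_sms()}')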