import os
import subprocess
import torch

from ._ops import ops


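# Toolkit discovery helpers. The paths found below are passed to `ops.init()`
# at the bottom of this module so the compiled runtime can JIT-compile its
# CUDA/CUTLASS templates.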
def _find_cuda_home():
    """Locate the CUDA toolkit root via CUDA_HOME/CUDA_PATH, `which nvcc`, or /usr/local/cuda."""
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        try:
            # Fall back to the `nvcc` on PATH: CUDA home is two levels above the binary.
            with open(os.devnull, 'w') as devnull:
                nvcc = subprocess.check_output(
                    ['which', 'nvcc'], stderr=devnull
                ).decode().rstrip('\r\n')
                cuda_home = os.path.dirname(os.path.dirname(nvcc))
        except Exception:
            # Last resort: the conventional install location.
            cuda_home = '/usr/local/cuda'
        if not os.path.exists(cuda_home):
            cuda_home = ''
    return cuda_home or ''


def _find_cutlass_include():
    """Find CUTLASS include path for JIT compilation of .cuh templates."""
    # 1. Explicit override via environment variable.
    cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
    if cutlass_include and os.path.isdir(cutlass_include):
        return cutlass_include

    # 2. A CUTLASS checkout pointed to by CUTLASS_HOME.
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        p = os.path.join(cutlass_home, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p

    # 3. Headers bundled with this package.
    pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if os.path.isdir(os.path.join(pkg_include, 'cute')):
        return pkg_include

    # 4. Headers shipped with the CUDA toolkit.
    cuda_home = _find_cuda_home()
    if cuda_home:
        cuda_inc = os.path.join(cuda_home, 'include')
        if os.path.isdir(os.path.join(cuda_inc, 'cute')):
            return cuda_inc

    # 5. Headers from an installed `cutlass` Python package.
    try:
        import cutlass as _cutlass
        cutlass_dir = os.path.dirname(_cutlass.__file__)
        p = os.path.join(cutlass_dir, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p
    except ImportError:
        pass

    # Nothing found: let the caller decide how to handle the empty path.
    return ''


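# Runtime knobs exposed by the compiled ops: the number of SMs the kernels may
# occupy and the `tc_util` setting (presumably tensor-core utilization).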
def set_num_sms(new_num_sms):
    ops.set_num_sms(new_num_sms)


def get_num_sms():
    return ops.get_num_sms()


def set_tc_util(new_tc_util):
    ops.set_tc_util(new_tc_util)


def get_tc_util():
    return ops.get_tc_util()


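# Thin wrappers over the cuBLASLt GEMM bindings in the four transpose layouts
# (NT/NN/TN/TT); `d` receives the result and `c` is an optional extra operand.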
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)


def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)


def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)


def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)


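# The bindings below are wrapped in a single try/except so that importing the
# package still succeeds if they cannot be set up; any failure is silently
# ignored.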
try:
    def fp8_fp4_gemm_nt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(a, b, d, c=None, recipe=None, recipe_a=None,
                        recipe_b=None, compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(a[0], a[1], b[0], b[1], d, c,
                            list(recipe) if recipe else None,
                            list(recipe_a) if recipe_a else None,
                            list(recipe_b) if recipe_b else None,
                            compiled_dims, disable_ue8m0_cast)

    # FP8-only names for the same entry points.
    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt

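    # FP8/FP4 operands are 2-tuples that get unpacked into the ops call; the
    # second element is assumed here to carry the scaling factors. A minimal
    # usage sketch with hypothetical tensor names:
    #   fp8_gemm_nt((a_q, a_sf), (b_q, b_sf), d)
    #   fp8_gemm_nt((a_q, a_sf), (b_q, b_sf), d, recipe=(1, 128, 128))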
    def m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout,
                                             recipe=None, recipe_a=None, recipe_b=None,
                                             compiled_dims="nk", disable_ue8m0_cast=False,
                                             use_psum_layout=False, expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(a, b, d, grouped_layout,
                                             recipe=None, recipe_a=None, recipe_b=None,
                                             compiled_dims="nk", disable_ue8m0_cast=False,
                                             use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                         recipe=None, recipe_a=None, recipe_b=None,
                                         compiled_dims="nk", disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

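    # The k-grouped variants below take the per-group K sizes both as a Python
    # list (`ks`) and as a device tensor (`ks_tensor`); this reading is assumed
    # from the parameter names rather than documented here.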
    def k_grouped_fp8_gemm_nt_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None,
                                         recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

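    # BF16 counterparts of the GEMM family above: plain tensors instead of
    # (tensor, scale-factor) pairs, and no quantization recipes.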
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout,
                                          compiled_dims="nk", use_psum_layout=False,
                                          expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout,
                                          compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m,
                                      compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor,
                                          c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)

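    # Einsum-style entry points: `expr` selects the contraction, and the FP8
    # variant again takes (tensor, scale-factor) pairs for `a` and `b`.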
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))

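    # Attention-related kernels: a head-split FP8 GEMM plus FP8 MQA-logit
    # helpers for both contiguous and paged KV caches.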
    def fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, recipe=None,
                                  compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(q, kv, weights, cu_seq_len_k_start,
                       cu_seq_len_k_end, clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
                             block_table, schedule_meta,
                             max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)

    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    def transform_sf_into_required_layout(sf, mn, k, recipe=None,
                                          recipe_ab=None, num_groups=None, is_sfa=False,
                                          disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()

    # Alternative names for the masked grouped GEMMs.
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except Exception:
    pass


from . import utils
from .utils import *

from . import testing


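# Initialize the compiled runtime with the package directory and the detected
# CUDA / CUTLASS include paths; failures are swallowed, presumably so the
# Python-level API can still be imported without a working toolchain.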
try:
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    pass


__version__ = '2.3.0'