# deep-gemm/torch-ext/deep_gemm/__init__.py
import os
import subprocess

import torch

from ._ops import ops


def _find_cuda_home():
    """Find the CUDA install path."""
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        try:
            with open(os.devnull, 'w') as devnull:
                nvcc = subprocess.check_output(
                    ['which', 'nvcc'], stderr=devnull
                ).decode().rstrip('\r\n')
                cuda_home = os.path.dirname(os.path.dirname(nvcc))
        except Exception:
            cuda_home = '/usr/local/cuda'
        if not os.path.exists(cuda_home):
            cuda_home = ''
    return cuda_home or ''


def _find_cutlass_include():
    """Find CUTLASS include path for JIT compilation of .cuh templates."""
    # 1. Explicit env var
    cutlass_include = os.environ.get('DG_CUTLASS_INCLUDE')
    if cutlass_include and os.path.isdir(cutlass_include):
        return cutlass_include

    # 2. CUTLASS_HOME env var
    cutlass_home = os.environ.get('CUTLASS_HOME')
    if cutlass_home:
        p = os.path.join(cutlass_home, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p

    # 3. Check the package include/ directory (bundled cute/cutlass headers)
    pkg_include = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'include')
    if os.path.isdir(os.path.join(pkg_include, 'cute')):
        return pkg_include

    # 4. Check CUDA_HOME/include (some CUDA 12.8+ installs ship cute/)
    cuda_home = _find_cuda_home()
    if cuda_home:
        cuda_inc = os.path.join(cuda_home, 'include')
        if os.path.isdir(os.path.join(cuda_inc, 'cute')):
            return cuda_inc

    # 5. Try the nvidia-cutlass Python package
    try:
        import cutlass as _cutlass
        cutlass_dir = os.path.dirname(_cutlass.__file__)
        p = os.path.join(cutlass_dir, 'include')
        if os.path.isdir(os.path.join(p, 'cute')):
            return p
    except ImportError:
        pass

    # Return an empty string; the C++ side will also check env vars
    return ""


def set_num_sms(new_num_sms):
    ops.set_num_sms(new_num_sms)


def get_num_sms():
    return ops.get_num_sms()


def set_tc_util(new_tc_util):
    ops.set_tc_util(new_tc_util)


def get_tc_util():
    return ops.get_tc_util()
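
# Usage sketch (values are assumed examples; the unit of set_tc_util is assumed
# to be a utilization percentage):
#   set_num_sms(get_num_sms() - 20)   # reserve a few SMs, e.g. for overlapped communication kernels
#   set_tc_util(90)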


# cuBLASLt GEMMs
def cublaslt_gemm_nt(a, b, d, c=None):
    ops.cublaslt_gemm_nt(a, b, d, c)


def cublaslt_gemm_nn(a, b, d, c=None):
    ops.cublaslt_gemm_nn(a, b, d, c)


def cublaslt_gemm_tn(a, b, d, c=None):
    ops.cublaslt_gemm_tn(a, b, d, c)


def cublaslt_gemm_tt(a, b, d, c=None):
    ops.cublaslt_gemm_tt(a, b, d, c)
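
# Usage sketch for the NT variant (shapes and dtypes are assumptions for
# illustration; the output tensor is pre-allocated by the caller, and c, if
# given, is assumed to act as an accumulator added to the product):
#   a = torch.randn(m, k, dtype=torch.bfloat16, device='cuda')
#   b = torch.randn(n, k, dtype=torch.bfloat16, device='cuda')
#   d = torch.empty(m, n, dtype=torch.bfloat16, device='cuda')
#   cublaslt_gemm_nt(a, b, d)   # d = a @ b.T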

try:
    # FP8/FP4 GEMMs
    def fp8_fp4_gemm_nt(
            a, b, d, c=None, recipe=None, recipe_a=None, recipe_b=None,
            compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nt(
            a[0], a[1], b[0], b[1], d, c,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_nn(
            a, b, d, c=None, recipe=None, recipe_a=None, recipe_b=None,
            compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_nn(
            a[0], a[1], b[0], b[1], d, c,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tn(
            a, b, d, c=None, recipe=None, recipe_a=None, recipe_b=None,
            compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tn(
            a[0], a[1], b[0], b[1], d, c,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_fp4_gemm_tt(
            a, b, d, c=None, recipe=None, recipe_a=None, recipe_b=None,
            compiled_dims="mn", disable_ue8m0_cast=False):
        ops.fp8_fp4_gemm_tt(
            a[0], a[1], b[0], b[1], d, c,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    fp8_gemm_nt = fp8_fp4_gemm_nt
    fp8_gemm_nn = fp8_fp4_gemm_nn
    fp8_gemm_tn = fp8_fp4_gemm_tn
    fp8_gemm_tt = fp8_fp4_gemm_tt
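
    # Usage sketch for the FP8 path (shapes, dtypes and the scaling-factor layout
    # are assumptions for illustration; a and b are (quantized tensor, scale) pairs):
    #   a = (a_fp8, a_sf)   # a_fp8: [m, k] float8_e4m3fn, a_sf: float32 per-block scales
    #   b = (b_fp8, b_sf)   # b_fp8: [n, k] float8_e4m3fn, b_sf: float32 per-block scales
    #   d = torch.empty(m, n, dtype=torch.bfloat16, device='cuda')
    #   fp8_gemm_nt(a, b, d)   # d = dequant(a) @ dequant(b).T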

    def m_grouped_fp8_fp4_gemm_nt_contiguous(
            a, b, d, grouped_layout, recipe=None, recipe_a=None, recipe_b=None,
            compiled_dims="nk", disable_ue8m0_cast=False, use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout,
            expected_m_for_psum_layout)

    m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous

    def m_grouped_fp8_fp4_gemm_nn_contiguous(
            a, b, d, grouped_layout, recipe=None, recipe_a=None, recipe_b=None,
            compiled_dims="nk", disable_ue8m0_cast=False, use_psum_layout=False):
        ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
            a[0], a[1], b[0], b[1], d, grouped_layout,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast, use_psum_layout)

    m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous

    def m_grouped_fp8_fp4_gemm_nt_masked(
            a, b, d, masked_m, expected_m, recipe=None, recipe_a=None,
            recipe_b=None, compiled_dims="nk", disable_ue8m0_cast=False):
        ops.m_grouped_fp8_fp4_gemm_nt_masked(
            a[0], a[1], b[0], b[1], d, masked_m, expected_m,
            list(recipe) if recipe else None,
            list(recipe_a) if recipe_a else None,
            list(recipe_b) if recipe_b else None,
            compiled_dims, disable_ue8m0_cast)

    m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked

    def k_grouped_fp8_gemm_nt_contiguous(
            a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_nt_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)

    def k_grouped_fp8_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c=None,
            recipe=(1, 1, 128), compiled_dims="mn"):
        ops.k_grouped_fp8_gemm_tn_contiguous(
            a[0], a[1], b[0], b[1], d, ks, ks_tensor, c,
            list(recipe), compiled_dims)
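
    # Usage sketch for the masked grouped GEMM (shapes are assumptions for
    # illustration): with num_groups experts, a[0] is [num_groups, m_max, k],
    # b[0] is [num_groups, n, k], d is [num_groups, m_max, n], masked_m is an
    # int32 [num_groups] tensor holding the valid row count of each group, and
    # expected_m is a Python int hint used to pick a kernel configuration.
    #   m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m)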

    # BF16 GEMMs
    def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nt(a, b, d, c, compiled_dims)

    def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
        ops.bf16_gemm_nn(a, b, d, c, compiled_dims)

    def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tn(a, b, d, c, compiled_dims)

    def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
        ops.bf16_gemm_tt(a, b, d, c, compiled_dims)

    def m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False,
            expected_m_for_psum_layout=None):
        ops.m_grouped_bf16_gemm_nt_contiguous(
            a, b, d, grouped_layout, compiled_dims,
            use_psum_layout, expected_m_for_psum_layout)

    def m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False):
        ops.m_grouped_bf16_gemm_nn_contiguous(
            a, b, d, grouped_layout, compiled_dims, use_psum_layout)

    def m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims="nk"):
        ops.m_grouped_bf16_gemm_nt_masked(
            a, b, d, masked_m, expected_m, compiled_dims)

    def k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c=None, compiled_dims="mn"):
        ops.k_grouped_bf16_gemm_tn_contiguous(
            a, b, d, ks, ks_tensor, c, compiled_dims)
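
    # Usage sketch (assumed shapes, same calling convention as cublaslt_gemm_nt
    # above but dispatched to the BF16 kernels): unlike the FP8 wrappers, a and b
    # are plain bf16 tensors rather than (tensor, scale) pairs.
    #   bf16_gemm_nt(a, b, d)   # a: [m, k], b: [n, k], d: [m, n]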

    # Einsum
    def einsum(expr, a, b, d, c=None, use_cublaslt=False):
        ops.einsum(expr, a, b, d, c, use_cublaslt)

    def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
        ops.fp8_einsum(expr, a[0], a[1], b[0], b[1], d, c, list(recipe))
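
    # Usage sketch (the expression string is an assumed example; supported
    # expressions are determined by the underlying op, not by this wrapper):
    #   einsum('bmk,bnk->bmn', a, b, d)   # d is pre-allocated by the caller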

    # Attention
    def fp8_gemm_nt_skip_head_mid(
            a, b, d, head_splits, recipe=None,
            compiled_dims="nk", disable_ue8m0_cast=False):
        ops.fp8_gemm_nt_skip_head_mid(
            a[0], a[1], b[0], b[1], d, list(head_splits),
            list(recipe) if recipe else None,
            compiled_dims, disable_ue8m0_cast)

    def fp8_mqa_logits(
            q, kv, weights, cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits=True, max_seqlen_k=0):
        return ops.fp8_mqa_logits(
            q, kv[0], kv[1], weights,
            cu_seq_len_k_start, cu_seq_len_k_end,
            clean_logits, max_seqlen_k)

    def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
        return ops.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms)

    def fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens, block_table,
            schedule_meta, max_context_len, clean_logits=False):
        return ops.fp8_paged_mqa_logits(
            q, fused_kv_cache, weights, context_lens,
            block_table, schedule_meta, max_context_len, clean_logits)
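
    # Usage sketch for the paged variant (argument roles are assumptions for
    # illustration): the metadata helper plans a per-SM schedule once, and the
    # logits kernel then consumes it together with the paged KV cache.
    #   schedule_meta = get_paged_mqa_logits_metadata(context_lens, block_kv, get_num_sms())
    #   logits = fp8_paged_mqa_logits(q, fused_kv_cache, weights, context_lens,
    #                                 block_table, schedule_meta, max_context_len)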

    # Hyperconnection
    def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
        ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

    # Layout
    def transform_sf_into_required_layout(
            sf, mn, k, recipe=None, recipe_ab=None, num_groups=None,
            is_sfa=False, disable_ue8m0_cast=False):
        return ops.transform_sf_into_required_layout(
            sf, mn, k,
            list(recipe) if recipe else None,
            list(recipe_ab) if recipe_ab else None,
            num_groups, is_sfa, disable_ue8m0_cast)

    def get_mk_alignment_for_contiguous_layout():
        return ops.get_mk_alignment_for_contiguous_layout()
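
    # Usage sketch (the padding step is an assumed caller-side pattern): for the
    # m-grouped contiguous GEMMs, each group's row block is expected to be padded
    # to this alignment before the groups are concatenated along M.
    #   align = get_mk_alignment_for_contiguous_layout()
    #   padded_m = (m + align - 1) // align * align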

    # Legacy aliases
    fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
    bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
except Exception:
    pass

# Utils
from . import utils
from .utils import *

# Testing
from . import testing

# Initialize (gracefully skip if CUDA is not available, e.g. in build sandboxes)
try:
    ops.init(
        os.path.dirname(os.path.abspath(__file__)),
        _find_cuda_home(),
        _find_cutlass_include()
    )
except Exception:
    pass

__version__ = '2.3.0'