import torch
from sgl_kernel.load_utils import _load_architecture_specific_ops, _preload_cuda_library
# Initialize the ops library based on current GPU.
# NOTE(review): _load_architecture_specific_ops presumably selects the compiled
# extension matching the detected architecture — confirm against load_utils.
common_ops = _load_architecture_specific_ops()
# Preload the CUDA library to avoid the issue of libcudart.so.12 not found.
# Only done on CUDA builds; torch.version.cuda is None on ROCm/CPU builds.
if torch.version.cuda is not None:
    _preload_cuda_library()
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
cutlass_mla_decode,
cutlass_mla_get_workspace_size,
merge_state,
merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
FusedSetKVBufferArg,
apply_rope_with_cos_sin_cache_inplace,
concat_mla_absorb_q,
concat_mla_k,
copy_to_gpu_no_ce,
downcast_fp8,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
rotary_embedding,
silu_and_mul,
)
from sgl_kernel.expert_specialization import (
es_fp8_blockwise_scaled_grouped_mm,
es_sm100_mxfp8_blockscaled_grouped_mm,
es_sm100_mxfp8_blockscaled_grouped_quant,
)
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
cutlass_scaled_fp4_mm,
dsv3_fused_a_gemm,
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
gptq_gemm,
gptq_shuffle,
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_grouped_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_8bit,
sgl_per_token_group_quant_fp8,
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
shuffle_rows,
silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_mla,
transfer_kv_per_layer,
transfer_kv_per_layer_mla,
)
from sgl_kernel.mamba import (
causal_conv1d_fn_cpu,
causal_conv1d_fwd,
causal_conv1d_update,
causal_conv1d_update_cpu,
chunk_gated_delta_rule_cpu,
)
from sgl_kernel.memory import set_kv_buffer_kernel, weak_ref_tensor
from sgl_kernel.moe import (
apply_shuffle_mul_sum,
fp8_blockwise_scaled_grouped_mm,
fused_qk_norm_rope,
kimi_k2_moe_fused_gate,
moe_align_block_size,
moe_fused_gate,
moe_sum,
moe_sum_reduce,
prepare_moe_input,
topk_sigmoid,
topk_softmax,
)
from sgl_kernel.quantization import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
from sgl_kernel.sampling import (
top_k_mask_logits,
top_k_renorm_prob,
top_p_renorm_prob,
)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
reconstruct_indices_from_tree_mask,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import (
fast_topk,
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)
from sgl_kernel.version import __version__
# gelu_quick is imported only on ROCm (HIP) builds; presumably the CUDA build
# of sgl_kernel.elementwise does not provide it — TODO confirm.
if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick
def create_greenctx_stream_by_value(*args, **kwargs):
    """Forward to ``sgl_kernel.spatial.create_greenctx_stream_by_value``.

    The import is deferred to call time so that importing ``sgl_kernel``
    itself does not require the spatial extension to be loadable.
    """
    import sgl_kernel.spatial as _spatial

    return _spatial.create_greenctx_stream_by_value(*args, **kwargs)
def get_sm_available(*args, **kwargs):
    """Forward to ``sgl_kernel.spatial.get_sm_available``.

    Imported lazily at call time; see the sibling spatial wrapper for the
    rationale (keeps the spatial extension optional at package import).
    """
    import sgl_kernel.spatial as _spatial

    return _spatial.get_sm_available(*args, **kwargs)