from typing import TYPE_CHECKING

import torch


if TYPE_CHECKING:
    # Static-analysis stub: present `register_fake` as a decorator factory
    # (name in, decorator out) without importing a concrete implementation.
    def register_fake(name):
        return lambda fn: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        # Older torch releases expose the same functionality under the name
        # `impl_abstract`.
        from torch.library import impl_abstract as register_fake


try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fallback for setups where the compiled extension is importable as a
    # top-level module rather than through the package.
    try:
        import _quantization

        ops = torch.ops._quantization

        def add_op_namespace_prefix(op_name: str) -> str:
            return f"_quantization::{op_name}"
    except ImportError:
        raise e


from .scalar_type import ScalarType


def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """FP8 weight-only GEMM using the Marlin kernel.

    Multiplies the (size_m, size_k) activation tensor `a` by a Marlin-packed
    FP8 weight matrix, returning a (size_m, size_n) tensor in `a.dtype`.
    """
    return ops.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )


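# Minimal usage sketch, kept as a comment: it needs a CUDA build of the ops
# plus weights/workspace prepared by this package's packing utilities, and
# the shapes below are illustrative assumptions, not requirements:
#
#   a = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
#   out = fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
#                         num_bits=8, size_m=16, size_n=11008, size_k=4096)
#   assert out.shape == (16, 11008)

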
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    has_zp: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    """GPTQ GEMM using the Marlin kernel.

    `g_idx` and `perm` carry the activation-order (act_order) permutation,
    `b_zeros` holds optional zero points (enabled via `has_zp`), and
    `b_q_type` describes the quantized weight type; the underlying op takes
    its integer id.
    """
    return ops.gptq_marlin_gemm(
        a,
        b_q_weight,
        b_scales,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        has_zp,
        use_fp32_reduce,
        is_zp_float,
    )


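# Hedged sketch of a 4-bit GPTQ call (illustrative only; `marlin_qweight`
# must already be in Marlin layout, e.g. via `gptq_marlin_repack` below):
#
#   out = gptq_marlin_gemm(a, marlin_qweight, b_scales, b_zeros, g_idx, perm,
#                          workspace, b_q_type, size_m, size_n, size_k,
#                          is_k_full=True, use_fp32_reduce=True)

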
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack a GPTQ-format quantized weight (with its act-order permutation
    `perm`) into the layout expected by the Marlin GEMM kernels."""
    return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)


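# Typical flow (sketch): repack once at load time, then reuse the result for
# every `gptq_marlin_gemm` call on that layer:
#
#   marlin_qweight = gptq_marlin_repack(gptq_qweight, perm, size_k, size_n,
#                                       num_bits=4)

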
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack an AWQ-format quantized weight into the Marlin layout. Unlike
    the GPTQ variant, no permutation tensor is needed."""
    return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


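# Sketch, mirroring the GPTQ case (bit-width is an illustrative assumption):
#
#   marlin_qweight = awq_marlin_repack(awq_qweight, size_k, size_n, num_bits=4)

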
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """GEMM using the original Marlin kernel (4-bit weights, fp16 compute):
    `a` is (size_m, size_k) and the result is (size_m, size_n)."""
    return ops.marlin_gemm(
        a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
    )


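# Hedged sketch (requires Marlin-packed weights and a zero-initialized
# workspace tensor; sizes are illustrative):
#
#   a = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
#   out = marlin_gemm(a, b_q_weight, b_scales, workspace, 8, 4096, 4096)

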
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Marlin GEMM for 2:4 semi-structured sparse quantized weights; `b_meta`
    carries the sparsity metadata."""
    return ops.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id,
        size_m, size_n, size_k,
    )


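# Sketch (the sparse weight, metadata, and workspace all come from the 2:4
# compression pipeline; the names here are placeholders):
#
#   out = gptq_marlin_24_gemm(a, sparse_qweight, b_meta, b_scales, workspace,
#                             b_q_type, size_m, size_n, size_k)

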
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """QQQ (W4A8) Marlin GEMM: int8 activations `a` with per-token (`s_tok`),
    per-channel (`s_ch`), and per-group (`s_group`) scales. The output is
    fp16, matching the fake implementation registered below."""
    return ops.marlin_qqq_gemm(
        a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
    )


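# Sketch (int8 activations; the scale tensors come from QQQ quantization):
#
#   out = marlin_qqq_gemm(a_int8, b_q_weight, s_tok, s_ch, s_group,
#                         workspace, size_m, size_n, size_k)
#   # out.dtype == torch.float16

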
# Register "fake" (meta) implementations so that torch.compile / torch.export
# can infer output shapes and dtypes without launching the CUDA kernels. The
# single `hasattr` check stands in for "the Marlin ops were compiled into
# this build".
if hasattr(ops, "gptq_marlin_24_gemm"):

    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
    def _fp8_marlin_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        num_bits: int,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
    def _gptq_marlin_24_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_meta: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
    def _gptq_marlin_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_scales: torch.Tensor,
        b_zeros: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        has_zp: bool = False,
        use_fp32_reduce: bool = False,
        is_zp_float: bool = False,
    ) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
    def _marlin_qqq_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        s_tok: torch.Tensor,
        s_ch: torch.Tensor,
        s_group: torch.Tensor,
        workspace: torch.Tensor,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        # QQQ produces fp16 output regardless of the activation dtype.
        return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device)

    @register_fake(add_op_namespace_prefix("marlin_gemm"))
    def _marlin_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device)
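

# With the fakes registered, these wrappers can be traced, e.g. (sketch;
# assumes a CUDA build and pre-packed weights as in the examples above):
#
#   compiled = torch.compile(lambda x: marlin_gemm(x, b_q_weight, b_scales,
#                                                  workspace, x.shape[0],
#                                                  size_n, size_k))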