# Author: drbh
# Migrated from kernels-community/quantization-bitsandbytes (commit bd05303, unverified).
import torch
from ._ops import ops
def gemm_4bit_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: int,
) -> torch.Tensor:
    """Run the compiled 4-bit GEMM kernel, running it in bfloat16.

    The kernel ``ops.gemm_4bit_forward`` is always handed a bfloat16
    activation. If ``input`` arrives in another dtype, it is cast to
    bfloat16 first, and the kernel's result is cast back to that
    original dtype. A bfloat16 ``input`` passes through with no casts.

    Args:
        input: Activation tensor. Only its dtype is inspected here.
            Shape requirements are enforced (if at all) by the kernel.
        weight: Forwarded unchanged to the kernel. Presumably the packed
            4-bit quantized weight; TODO confirm against ``_ops``.
        absmax: Forwarded unchanged to the kernel. Presumably the
            per-block dequantization scales; TODO confirm.
        blocksize: Forwarded unchanged to the kernel. Presumably the
            number of elements sharing one ``absmax`` entry; TODO confirm.
        quant_type: Forwarded unchanged to the kernel. Presumably an
            integer code selecting the 4-bit format; TODO confirm.

    Returns:
        The kernel output. It is cast to ``input``'s original dtype when
        that dtype was not bfloat16.
    """
    caller_dtype = input.dtype
    # Decide once whether a round-trip through bfloat16 is required.
    needs_cast = caller_dtype != torch.bfloat16

    kernel_input = input.to(torch.bfloat16) if needs_cast else input
    result = ops.gemm_4bit_forward(
        kernel_input, weight, absmax, blocksize, quant_type
    )
    # Restore the caller's dtype so the precision change stays internal.
    return result.to(caller_dtype) if needs_cast else result