File size: 496 Bytes
fe6a903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from ._ops import ops

def gemm_4bit_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: int,
) -> torch.Tensor:
    """Run the 4-bit GEMM op, performing the computation in bfloat16.

    ``input`` is promoted to bfloat16 before the call (presumably the
    underlying kernel requires it — confirm against ``ops``), and the
    result is cast back so the caller sees its original dtype.
    """
    caller_dtype = input.dtype
    needs_cast = caller_dtype != torch.bfloat16

    # Promote only when necessary to avoid a redundant copy.
    compute_input = input.to(torch.bfloat16) if needs_cast else input
    result = ops.gemm_4bit_forward(
        compute_input, weight, absmax, blocksize, quant_type
    )

    return result.to(caller_dtype) if needs_cast else result