from typing import Optional

import torch

import quantization as ops

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported() -> bool:
    # Marlin FP8 kernels need an NVIDIA GPU with compute capability >= 8.0
    # (Ampere or newer).
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return capability >= 80


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
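    """Run the Marlin FP8 GEMM on `input` and return the result.

    `weight` must already be repacked into the Marlin layout and
    `weight_scale` permuted to match (see `prepare_fp8_layer_for_marlin`
    below); `size_n` and `size_k` are the output and input feature
    dimensions of the original weight matrix.
    """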
    # Collapse all leading dimensions so the kernel sees a 2D (M, K) input.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)

    return output.reshape(out_shape)


def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
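    """Repack an fp8 linear layer in place into the Marlin layout.

    Replaces `layer.weight` and `layer.weight_scale` and attaches a
    `layer.workspace` buffer for use with `apply_fp8_marlin_linear`.
    The `strategy` argument is currently unused.
    """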
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE: scratch buffer used by the Marlin kernel.
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT: repack the fp8 weights (4 values per int32) into Marlin format.
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # SCALES: cast back to the layer's original activation dtype and permute
    # into the layout the Marlin kernel expects.
    scales = layer.weight_scale.to(layer.orig_dtype)
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
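

# A minimal end-to-end sketch (illustrative only, not exercised here): it
# assumes a CUDA device with compute capability >= 8.0 and a `layer` that
# already holds fp8 `weight`, `weight_scale`, `orig_dtype`, and the
# partition-size attributes used above.
#
#   if is_fp8_marlin_supported():
#       prepare_fp8_layer_for_marlin(layer)
#       y = apply_fp8_marlin_linear(
#           input=x,
#           weight=layer.weight,
#           weight_scale=layer.weight_scale,
#           workspace=layer.workspace,
#           size_n=layer.output_size_per_partition,
#           size_k=layer.input_size_per_partition,
#           bias=None,
#       )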


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to GPTQ format (packed int32 elements): every 4
    consecutive values along dim 0 become the 4 little-endian bytes of one
    int32, so dim 0 shrinks by a factor of 4.
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Group every 4 consecutive rows so each group packs into a single int32.
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Reinterpret the fp8 values as raw bytes (same itemsize, no copy).
    byte_tensor = reshaped.view(torch.uint8)

    # Assemble each int32 from its 4 bytes, least-significant byte first.
    packed = (
        byte_tensor[:, 0].to(torch.int32)
        | (byte_tensor[:, 1].to(torch.int32) << 8)
        | (byte_tensor[:, 2].to(torch.int32) << 16)
        | (byte_tensor[:, 3].to(torch.int32) << 24)
    )

    return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
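

if __name__ == "__main__":
    # Minimal sanity check of the packing layout (a sketch: CPU-only, does
    # not touch the custom CUDA ops, and assumes a little-endian host).
    # Four consecutive fp8 rows should collapse into one int32 row whose
    # bytes are the original fp8 bit patterns.
    fp8 = torch.randn(8, 16).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(fp8)
    assert packed.shape == (2, 16) and packed.dtype == torch.int32

    # Split each int32 back into its 4 bytes and undo the row interleaving.
    recovered = (
        packed.view(torch.uint8).reshape(2, 16, 4).permute(0, 2, 1).reshape(8, 16)
    )
    assert torch.equal(recovered, fp8.view(torch.uint8))
    print("pack_fp8_to_int32 round-trip OK")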