| from typing import Optional | |
| import torch | |
try:
    from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
except ImportError:
    # petit-kernel is an optional dependency. When it cannot be imported,
    # stand-in functions are defined that report the missing package.
    # NOTE(review): functions with these same names are defined again further
    # down in this module -- confirm which definitions callers end up with.

    def _check_petit_nvfp4_supported(
        quant_method: str, group_size: Optional[int]
    ) -> tuple[bool, Optional[str]]:
        """Report Petit as unsupported, whatever the arguments are.

        Returns:
            ``(False, reason)`` where ``reason`` tells the user how to install
            the missing package.
        """
        reason = (
            "Petit is not installed. Please install it with `pip install petit-kernel`."
        )
        return False, reason

    def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
        """Always fail, because the Petit kernels are unavailable.

        Raises:
            ValueError: unconditionally, with an installation hint.
        """
        message = (
            "Petit is not installed. Please install it with `pip install petit-kernel`."
        )
        raise ValueError(message)

    def apply_petit_nvfp4_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_scale_2: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Always fail, because the Petit kernels are unavailable.

        Raises:
            ValueError: unconditionally, with an installation hint.
        """
        message = (
            "Petit is not installed. Please install it with `pip install petit-kernel`."
        )
        raise ValueError(message)
def _check_petit_nvfp4_supported(
    quant_method: str, group_size: Optional[int]
) -> tuple[bool, Optional[str]]:
    """Decide whether Petit can serve the given quantization config.

    Args:
        quant_method: Name of the quantization method; only ``"NVFP4"`` is
            accepted.
        group_size: Quantization group size. ``None`` skips the group-size
            check; otherwise it must be 16.

    Returns:
        ``(True, None)`` when the config is accepted and ``petit_kernel`` can
        be imported, otherwise ``(False, reason)`` with a human-readable
        explanation.
    """
    # Config mismatches are reported first: installing petit-kernel would not
    # make an unsupported configuration work.
    if quant_method != "NVFP4":
        return (
            False,
            "Petit currently only supports: NVFP4"
            " quantizations in sglang. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.",
        )
    if group_size is not None and group_size != 16:
        return (
            False,
            "Petit currently only supports: group_size=16 quantizations.",
        )

    # This definition is the one in effect even when the module-level
    # petit_kernel import failed, so the availability check has to live here;
    # otherwise a missing package would be reported as "supported".
    try:
        import petit_kernel  # noqa: F401
    except ImportError:
        return (
            False,
            "Petit is not installed. Please install it with `pip install petit-kernel`.",
        )
    return (True, None)
def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
    """Raise if the quantization config is rejected by the support check.

    Args:
        quant_method: Quantization method name, forwarded to
            ``_check_petit_nvfp4_supported``.
        group_size: Quantization group size (or ``None``), forwarded likewise.

    Raises:
        ValueError: carrying the reason returned by the support check when
            the config is not supported.
    """
    is_supported, reason = _check_petit_nvfp4_supported(quant_method, group_size)
    if is_supported:
        return
    raise ValueError(reason)
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
    """Convert a layer's NVFP4 weight and scales to Petit's format, in place.

    Reads ``layer.output_size_per_partition``, ``layer.input_size_per_partition``,
    ``layer.weight`` and ``layer.weight_scale``; replaces ``layer.weight`` and
    ``layer.weight_scale`` with new non-trainable parameters.

    Args:
        layer: Module carrying the attributes listed above.

    Raises:
        ValueError: if ``petit_kernel`` cannot be imported.
    """
    # Resolve the kernel helpers here so that a missing petit-kernel surfaces
    # as a ValueError with an install hint rather than a NameError on the
    # module-level names that were never bound.
    try:
        from petit_kernel import process_nvfp4_scales, repack_nvfp4
    except ImportError as err:
        raise ValueError(
            "Petit is not installed. Please install it with `pip install petit-kernel`."
        ) from err

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    # Repack weights to petit format. The weight storage is reinterpreted as
    # int32 before repacking (assumes a packed-FP4 byte layout -- TODO confirm).
    qweight = layer.weight.view(torch.int32).contiguous()
    petit_qweight = repack_nvfp4(qweight, size_n=part_size_n, size_k=part_size_k)
    layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)

    # Permute scales
    weight_scale = process_nvfp4_scales(
        scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n
    )
    layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
def apply_petit_nvfp4_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the Petit NVFP4 matmul kernel on ``input`` and restore its shape.

    Args:
        input: Activations; all leading dimensions are flattened into one
            before the kernel call.
        weight: Passed to the kernel as ``b``.
        weight_scale: Passed to the kernel as ``s``.
        weight_scale_2: Passed to the kernel as ``global_scale``.
        size_n: Size of the output's last dimension; forwarded to the kernel.
        size_k: Forwarded to the kernel as ``size_k``.
        bias: Optional tensor added in place to the kernel output.

    Returns:
        The kernel output reshaped to ``input``'s leading dimensions followed
        by ``size_n``.
    """
    flat_input = input.reshape(-1, input.shape[-1])
    target_shape = (*input.shape[:-1], size_n)

    # TODO: Use auto-tuning to find the performant solution_id
    # NOTE(review): solution_id=-1 presumably selects the kernel's default
    # solution -- confirm against petit_kernel.
    result = mul_nvfp4_a16(
        a=flat_input,
        b=weight,
        s=weight_scale,
        global_scale=weight_scale_2,
        size_m=flat_input.size(0),
        size_n=size_n,
        size_k=size_k,
        solution_id=-1,
    )

    if bias is not None:
        # In-place add: avoids allocating a second output tensor.
        result.add_(bias)

    return result.reshape(target_shape)
Xet Storage Details
- Size:
- 3.25 kB
- Xet hash:
- dd87190642ab8d3a8ea9b7562497bccdee1fbaca35a972297507e0032c1e52db
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.