|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
# Quantization modes accepted by ``quantize_model``; ``None`` = no quantization.
# NOTE(review): "fp8" and "w8a8" are handled by ``quantize_model`` but are
# never appended here — confirm whether they should also be advertised.
QUANTS = [
    None
]


# FlashInfer is an optional dependency: only advertise "nvfp4" when its
# fp4 quantization/GEMM kernels can actually be imported.
try:
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout

    QUANTS.append("nvfp4")
except ImportError:
    pass
|
|
|
|
|
|
|
|
@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Quantize a bf16 activation to NVFP4 on the fly and run an FP4 GEMM.

    Args:
        a_bf16: activation matrix in bf16.
        b_fp4_T: pre-quantized, transposed fp4 weight.
        a_global_sf: global scale factor applied during activation quantization.
        b_sf_T: transposed per-block weight scale factors.
        alpha: output rescale factor for the GEMM.

    Returns:
        bf16 result of the fp4 matrix multiply.
    """
    quant_a, quant_a_sf = nvfp4_quantize(
        a_bf16,
        a_global_sf,
        sfLayout=SfLayout.layout_128x4,
        do_shuffle=False,
    )
    return mm_fp4(
        quant_a,
        b_fp4_T,
        quant_a_sf,
        b_sf_T,
        alpha,
        out_dtype=torch.bfloat16,
        backend="cutlass",
    )
|
|
|
|
|
|
|
|
@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) kernel: allocate an output of the correct shape and dtype
    without running the real GEMM, for tracing/compile."""
    rows = a_bf16.shape[0]
    cols = b_fp4_T.shape[1]
    return torch.empty((rows, cols), device=a_bf16.device, dtype=torch.bfloat16)
|
|
|
|
|
|
|
|
class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization.

    The wrapped ``nn.Linear``'s weight is quantized to NVFP4 once at
    construction time; activations are quantized on the fly inside the
    ``world_engine::fp4_linear`` custom op during ``forward``.
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()

        self.in_features = lin.in_features
        self.out_features = lin.out_features

        # The NVFP4 kernels require 32-aligned shapes.
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0, (
            "features % 32 != 0, nvfp4 disallowed"
        )

        self.weight = nn.Parameter(lin.weight.detach().clone())

        # Fail early: quantization below runs on the weight's device. The
        # original placed this assert *after* nvfp4_quantize had already run,
        # which defeated its purpose.
        assert self.weight.is_cuda, "Weights need to be on GPU before quantization"

        # Quantized weight state, filled in below.
        self._weight_fp4_T: Optional[torch.Tensor] = None
        self._weight_scales_T: Optional[torch.Tensor] = None
        self._alpha: Optional[torch.Tensor] = None
        self._dummy_scale: Optional[torch.Tensor] = None
        self._weight_global_sf = None

        with torch.no_grad():
            # Activations use an identity (1.0) global scale factor.
            self._dummy_scale = torch.full(
                (1,), 1.0, device=self.weight.device, dtype=torch.float32
            )
            weight_bf16 = (
                self.weight.to(torch.bfloat16).to(self.weight.device).contiguous()
            )
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            # NOTE(review): FlashInfer examples often use (448 * 6) / amax as
            # the fp4 global scale; 1 / amax maps weights into [-1, 1] —
            # confirm this is the intended convention.
            self._weight_global_sf = 1.0 / weight_amax
            # alpha undoes both scale factors after the GEMM so the output
            # returns to the original magnitude.
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16,
                self._weight_global_sf,
                sfLayout=SfLayout.layout_128x4,
                do_shuffle=False,
            )
            # mm_fp4 consumes the transposed weight and scale tensors.
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

        # Run one dummy forward so kernel selection/compilation happens at
        # construction time rather than on the first real input.
        lazy_x = torch.zeros(
            (1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16
        )
        fp4_linear(
            lazy_x,
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using FP4 quantization and FlashInfer GEMM.

        Args:
            x: input of shape [..., in_features]; flattened to 2D internally.

        Returns:
            bf16 output of shape [..., out_features].
        """
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))
|
|
|
|
|
|
|
|
class FP8W8A8Linear(nn.Module): |
|
|
__constants__ = ("in_features", "out_features") |
|
|
|
|
|
def __init__(self, lin: nn.Linear): |
|
|
super().__init__() |
|
|
self.in_features, self.out_features = lin.in_features, lin.out_features |
|
|
|
|
|
f8 = torch.float8_e4m3fn |
|
|
inv = 1.0 / float(torch.finfo(f8).max) |
|
|
self._inv = inv |
|
|
|
|
|
w = lin.weight.detach() |
|
|
ws = (w.abs().amax() * inv).clamp_min(1e-8).float() |
|
|
wf8 = (w / ws.to(w.dtype)).to(f8).contiguous() |
|
|
self.register_buffer("wT", wf8.t()) |
|
|
self.register_buffer("ws", ws) |
|
|
|
|
|
if lin.bias is None: |
|
|
self.bias = None |
|
|
else: |
|
|
self.register_buffer("bias", lin.bias.detach().to(torch.float16)) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
s = x.shape |
|
|
x2 = x.reshape(-1, s[-1]) |
|
|
|
|
|
xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float() |
|
|
xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous() |
|
|
|
|
|
y = torch._scaled_mm( |
|
|
xf8, |
|
|
self.wT, |
|
|
xs, |
|
|
self.ws, |
|
|
bias=self.bias, |
|
|
out_dtype=torch.float16, |
|
|
use_fast_accum=True, |
|
|
) |
|
|
return y.reshape(*s[:-1], self.out_features).to(x.dtype) |
|
|
|
|
|
|
|
|
class FP8Linear(nn.Module): |
|
|
def __init__(self, lin: nn.Linear): |
|
|
super().__init__() |
|
|
self.in_features, self.out_features = lin.in_features, lin.out_features |
|
|
|
|
|
self.bias = ( |
|
|
nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn)) |
|
|
if lin.bias is not None |
|
|
else None |
|
|
) |
|
|
w_amax = lin.weight.data.clone().amax().float().squeeze() |
|
|
w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn) |
|
|
self.register_buffer("w_amax", w_amax) |
|
|
self.register_buffer("weightT", w.t()) |
|
|
self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Forward pass using FP8 matmul. |
|
|
|
|
|
Args: |
|
|
x: Input tensor of shape [..., in_features] (flattens if > 2D) |
|
|
|
|
|
Returns: |
|
|
Output tensor of shape [..., out_features] in BF16 format, unflattened if input is > 2D |
|
|
""" |
|
|
|
|
|
|
|
|
x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous() |
|
|
|
|
|
result = torch._scaled_mm( |
|
|
x_fp8, |
|
|
self.weightT, |
|
|
bias=self.bias, |
|
|
scale_a=self.dummy_scale, |
|
|
scale_b=self.w_amax, |
|
|
out_dtype=torch.bfloat16, |
|
|
use_fast_accum=True, |
|
|
) |
|
|
|
|
|
return result.reshape(x.shape[:-1] + (-1,)) |
|
|
|
|
|
|
|
|
def quantize_model(model: nn.Module, quant: Optional[str]) -> nn.Module:
    """Recursively replace eligible ``nn.Linear`` modules with quantized ones.

    Args:
        model: module tree to quantize in place.
        quant: one of ``"w8a8"``, ``"nvfp4"``, ``"fp8"``, or ``None`` (no-op).

    Returns:
        The same ``model`` instance, modified in place.

    Raises:
        KeyError: if ``quant`` is not a recognized quantization scheme.
    """
    if quant is None:
        return model

    def eligible(m: nn.Module) -> bool:
        # Only bf16 Linear layers with 32-aligned dims qualify (the fp4/fp8
        # kernels require shapes divisible by 32).
        if not isinstance(m, nn.Linear):
            return False
        w = m.weight
        if getattr(w, "dtype", None) != torch.bfloat16:
            return False
        out_f, in_f = w.shape
        return (out_f % 32 == 0) and (in_f % 32 == 0)

    new_linear = {
        "w8a8": FP8W8A8Linear,
        "nvfp4": FP4Linear,
        "fp8": FP8Linear,
    }[quant]

    for name, child in model.named_children():
        if eligible(child):
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model
|
|
|