"""
PyTorch 原生 FP8 权重量化 (v3 _scaled_mm 版)
=============================================
v3: torch._scaled_mm 原生 FP8 matmul (SM89 硬件加速)。
关键: b 必须用 .t() (非 .T.contiguous()), 否则 CUBLAS_STATUS_NOT_SUPPORTED。
输入动态量化用 bf16 精度计算 amax, 避免 float32 临时拷贝的显存开销。

用法:
    import native_fp8_patch
    native_fp8_patch.quantize_transformer_fp8(pipe.transformer)
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class FP8Linear(nn.Module): |
| """FP8 weight-only quantized Linear. v3: _scaled_mm + bf16 amax.""" |
|
|
| def __init__(self, original_linear: nn.Linear, compute_dtype=torch.bfloat16): |
| super().__init__() |
|
|
| weight = original_linear.weight.data.float() |
|
|
| amax = weight.abs().max() |
| scale = (amax / 448.0).clamp(min=1e-12) |
|
|
| fp8_weight = (weight / scale).to(torch.float8_e4m3fn) |
|
|
| self.register_buffer("fp8_weight", fp8_weight) |
| self.register_buffer("weight_scale", scale.view(())) |
|
|
| if original_linear.bias is not None: |
| self.register_buffer("bias", original_linear.bias.data.to(compute_dtype)) |
| else: |
| self.bias = None |
|
|
| self.in_features = original_linear.in_features |
| self.out_features = original_linear.out_features |
| self.compute_dtype = compute_dtype |
|
|
| del weight, fp8_weight |
|
|
| @property |
| def weight(self): |
| return self.fp8_weight.to(self.compute_dtype) * self.weight_scale.to(self.compute_dtype) |
|
|
| def forward(self, x): |
| orig_shape = x.shape |
| x_2d = x.reshape(-1, self.in_features) |
|
|
| |
| x_amax = x_2d.detach().abs().amax() |
| x_scale = (x_amax.float() / 448.0).clamp(min=1e-12).view(()) |
| x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn) |
|
|
| out = torch._scaled_mm( |
| x_fp8, self.fp8_weight.t(), |
| scale_a=x_scale, scale_b=self.weight_scale, |
| out_dtype=self.compute_dtype, use_fast_accum=True, |
| ) |
|
|
| if self.bias is not None: |
| out = out + self.bias |
|
|
| return out.reshape(*orig_shape[:-1], self.out_features) |
|
|
| def extra_repr(self): |
| return "in=%d, out=%d, fp8+_scaled_mm" % (self.in_features, self.out_features) |
|
|
|
|
def quantize_transformer_fp8(transformer, compute_dtype=torch.bfloat16, verbose=True):
    """Replace every ``nn.Linear`` inside *transformer* with an ``FP8Linear``.

    Args:
        transformer: root module; its Linear submodules are swapped in place.
        compute_dtype: dtype for the matmul output and the bias buffer.
        verbose: print progress and a final summary.

    Returns:
        dict with ``num_layers``, ``original_mb``, ``quantized_mb``,
        ``saved_mb`` and ``compression`` statistics.
    """
    original_bytes = 0
    quantized_bytes = 0
    count = 0

    # Snapshot first: mutating the module tree while named_modules() is
    # iterating it would be unsafe.
    linear_layers = [
        (name, module)
        for name, module in transformer.named_modules()
        if isinstance(module, nn.Linear)
    ]
    total = len(linear_layers)

    if verbose:
        print("[native_fp8] Quantizing %d Linear layers (_scaled_mm v3)..." % total)

    for i in range(total):
        name, module = linear_layers[i]

        original_bytes += module.weight.numel() * module.weight.element_size()

        fp8_module = FP8Linear(module, compute_dtype=compute_dtype)

        # 1 byte per FP8 element plus one float32 scale.
        quantized_bytes += fp8_module.fp8_weight.numel() * 1 + 4

        # Walk to the parent module and swap the child in place.
        parts = name.split(".")
        parent = transformer
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], fp8_module)

        # BUGFIX: release *all* references to the original layer. Previously
        # only the local name was deleted while the snapshot list kept the
        # module alive, so no original weight was actually freed until the
        # whole loop finished and peak memory stayed high.
        linear_layers[i] = None
        del module
        count += 1
        if verbose and (i + 1) % 200 == 0:
            print("[native_fp8] %d/%d layers done" % (i + 1, total))

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    stats = {
        "num_layers": count,
        "original_mb": original_bytes / 1024**2,
        "quantized_mb": quantized_bytes / 1024**2,
        "saved_mb": (original_bytes - quantized_bytes) / 1024**2,
        "compression": "%.1fx" % (original_bytes / max(quantized_bytes, 1)),
    }

    if verbose:
        print("[native_fp8] Done: %d layers, %.0f MB -> %.0f MB (saved %.0f MB, %.1fx)" % (
            stats["num_layers"], stats["original_mb"], stats["quantized_mb"],
            stats["saved_mb"], original_bytes / max(quantized_bytes, 1)
        ))

    return stats
|
|
|
|
def enable_fp8(infer_state_override=True):
    """Best-effort switch-off of the angelslim FP8 GEMM path.

    Silently does nothing when the hyvideo package (or the flag) is
    unavailable.
    """
    if not infer_state_override:
        return
    try:
        from hyvideo.commons.infer_state import get_infer_state
        get_infer_state().use_fp8_gemm = False
        print("[native_fp8] Disabled angelslim FP8 (use_fp8_gemm=False)")
    except Exception:
        # Deliberate best-effort: the project module may not be importable.
        pass
|
|
|
|
if __name__ == "__main__":
    # Self-test: bf16 baseline vs FP8 path on a small Linear stack.
    # Requires a CUDA device with FP8 _scaled_mm support (SM89+).
    import time

    banner = "=" * 60
    print(banner)
    print("native_fp8_patch self-test (v3 _scaled_mm)")
    print(banner)

    torch.manual_seed(42)
    device = "cuda"

    class MockTransformer(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(3072, 3072, bias=False)
            self.linear2 = nn.Linear(3072, 12288, bias=True)
            self.linear3 = nn.Linear(12288, 3072, bias=True)

        def forward(self, x):
            return self.linear3(F.gelu(self.linear2(self.linear1(x))))

    model = MockTransformer().to(torch.bfloat16).to(device)
    x = torch.randn(1, 100, 3072, dtype=torch.bfloat16, device=device)

    with torch.no_grad():
        y_bf16 = model(x)

    stats = quantize_transformer_fp8(model)

    # NOTE(review): original indentation was lost; assuming warmup and the
    # timing loop both run under no_grad — confirm against the original file.
    with torch.no_grad():
        _ = model(x)  # warmup
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(200):
            y_fp8 = model(x)
        torch.cuda.synchronize()
    t = (time.time() - t0) / 200 * 1000

    cos = F.cosine_similarity(
        y_bf16.reshape(-1, 3072), y_fp8.reshape(-1, 3072), dim=-1
    ).mean().item()
    verdict = "PASS" if cos > 0.995 else "WARN"
    print(" Cosine: %.6f Time: %.2f ms %s" % (cos, t, verdict))
|