#!/usr/bin/env python3 """ PyTorch 原生 FP8 权重量化 (v3 _scaled_mm 版) ============================================= v3: torch._scaled_mm 原生 FP8 matmul (SM89 硬件加速)。 关键: b 必须用 .t() (非 .T.contiguous()), 否则 CUBLAS_STATUS_NOT_SUPPORTED。 输入动态量化用 bf16 精度计算 amax,避免 float32 临时拷贝的显存开销。 用法: import native_fp8_patch native_fp8_patch.quantize_transformer_fp8(pipe.transformer) """ import torch import torch.nn as nn import torch.nn.functional as F class FP8Linear(nn.Module): """FP8 weight-only quantized Linear. v3: _scaled_mm + bf16 amax.""" def __init__(self, original_linear: nn.Linear, compute_dtype=torch.bfloat16): super().__init__() weight = original_linear.weight.data.float() amax = weight.abs().max() scale = (amax / 448.0).clamp(min=1e-12) fp8_weight = (weight / scale).to(torch.float8_e4m3fn) self.register_buffer("fp8_weight", fp8_weight) self.register_buffer("weight_scale", scale.view(())) if original_linear.bias is not None: self.register_buffer("bias", original_linear.bias.data.to(compute_dtype)) else: self.bias = None self.in_features = original_linear.in_features self.out_features = original_linear.out_features self.compute_dtype = compute_dtype del weight, fp8_weight @property def weight(self): return self.fp8_weight.to(self.compute_dtype) * self.weight_scale.to(self.compute_dtype) def forward(self, x): orig_shape = x.shape x_2d = x.reshape(-1, self.in_features) # bf16 精度算 amax — 避免 .float() 拷贝省显存 x_amax = x_2d.detach().abs().amax() x_scale = (x_amax.float() / 448.0).clamp(min=1e-12).view(()) x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn) out = torch._scaled_mm( x_fp8, self.fp8_weight.t(), scale_a=x_scale, scale_b=self.weight_scale, out_dtype=self.compute_dtype, use_fast_accum=True, ) if self.bias is not None: out = out + self.bias return out.reshape(*orig_shape[:-1], self.out_features) def extra_repr(self): return "in=%d, out=%d, fp8+_scaled_mm" % (self.in_features, self.out_features) def quantize_transformer_fp8(transformer, compute_dtype=torch.bfloat16, verbose=True): """将 transformer 中所有 Linear 层的权重量化为 FP8。""" original_bytes = 0 quantized_bytes = 0 count = 0 linear_layers = [] for name, module in transformer.named_modules(): if isinstance(module, nn.Linear): linear_layers.append((name, module)) if verbose: print("[native_fp8] Quantizing %d Linear layers (_scaled_mm v3)..." % len(linear_layers)) for i, (name, module) in enumerate(linear_layers): orig_size = module.weight.numel() * module.weight.element_size() original_bytes += orig_size fp8_module = FP8Linear(module, compute_dtype=compute_dtype) quant_size = fp8_module.fp8_weight.numel() * 1 + 4 quantized_bytes += quant_size parts = name.split(".") parent = transformer for part in parts[:-1]: parent = getattr(parent, part) setattr(parent, parts[-1], fp8_module) del module count += 1 if verbose and (i + 1) % 200 == 0: print("[native_fp8] %d/%d layers done" % (i + 1, len(linear_layers))) import gc gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() stats = { "num_layers": count, "original_mb": original_bytes / 1024**2, "quantized_mb": quantized_bytes / 1024**2, "saved_mb": (original_bytes - quantized_bytes) / 1024**2, "compression": "%.1fx" % (original_bytes / max(quantized_bytes, 1)), } if verbose: print("[native_fp8] Done: %d layers, %.0f MB -> %.0f MB (saved %.0f MB, %.1fx)" % ( stats["num_layers"], stats["original_mb"], stats["quantized_mb"], stats["saved_mb"], original_bytes / max(quantized_bytes, 1) )) return stats def enable_fp8(infer_state_override=True): if infer_state_override: try: from hyvideo.commons.infer_state import get_infer_state state = get_infer_state() state.use_fp8_gemm = False print("[native_fp8] Disabled angelslim FP8 (use_fp8_gemm=False)") except Exception: pass if __name__ == "__main__": import time print("=" * 60) print("native_fp8_patch self-test (v3 _scaled_mm)") print("=" * 60) torch.manual_seed(42) device = "cuda" class MockTransformer(nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(3072, 3072, bias=False) self.linear2 = nn.Linear(3072, 12288, bias=True) self.linear3 = nn.Linear(12288, 3072, bias=True) def forward(self, x): return self.linear3(F.gelu(self.linear2(self.linear1(x)))) model = MockTransformer().to(torch.bfloat16).to(device) x = torch.randn(1, 100, 3072, dtype=torch.bfloat16, device=device) with torch.no_grad(): y_bf16 = model(x) stats = quantize_transformer_fp8(model) with torch.no_grad(): _ = model(x) torch.cuda.synchronize() t0 = time.time() for _ in range(200): y_fp8 = model(x) torch.cuda.synchronize() t = (time.time() - t0) / 200 * 1000 cos = F.cosine_similarity(y_bf16.reshape(-1, 3072), y_fp8.reshape(-1, 3072), dim=-1).mean().item() print(" Cosine: %.6f Time: %.2f ms %s" % (cos, t, "PASS" if cos > 0.995 else "WARN"))