HY-WorldPlay-FP8 / scripts /native_fp8_patch.py
vibegavin's picture
Initial release: FP8 quantized weights + turbo3 scripts + video demos
881f988 verified
#!/usr/bin/env python3
"""
PyTorch 原生 FP8 权重量化 (v3 _scaled_mm 版)
=============================================
v3: torch._scaled_mm 原生 FP8 matmul (SM89 硬件加速)。
关键: b 必须用 .t() (非 .T.contiguous()), 否则 CUBLAS_STATUS_NOT_SUPPORTED。
输入动态量化用 bf16 精度计算 amax,避免 float32 临时拷贝的显存开销。
用法:
import native_fp8_patch
native_fp8_patch.quantize_transformer_fp8(pipe.transformer)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class FP8Linear(nn.Module):
"""FP8 weight-only quantized Linear. v3: _scaled_mm + bf16 amax."""
def __init__(self, original_linear: nn.Linear, compute_dtype=torch.bfloat16):
super().__init__()
weight = original_linear.weight.data.float()
amax = weight.abs().max()
scale = (amax / 448.0).clamp(min=1e-12)
fp8_weight = (weight / scale).to(torch.float8_e4m3fn)
self.register_buffer("fp8_weight", fp8_weight)
self.register_buffer("weight_scale", scale.view(()))
if original_linear.bias is not None:
self.register_buffer("bias", original_linear.bias.data.to(compute_dtype))
else:
self.bias = None
self.in_features = original_linear.in_features
self.out_features = original_linear.out_features
self.compute_dtype = compute_dtype
del weight, fp8_weight
@property
def weight(self):
return self.fp8_weight.to(self.compute_dtype) * self.weight_scale.to(self.compute_dtype)
def forward(self, x):
orig_shape = x.shape
x_2d = x.reshape(-1, self.in_features)
# bf16 精度算 amax — 避免 .float() 拷贝省显存
x_amax = x_2d.detach().abs().amax()
x_scale = (x_amax.float() / 448.0).clamp(min=1e-12).view(())
x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_fp8, self.fp8_weight.t(),
scale_a=x_scale, scale_b=self.weight_scale,
out_dtype=self.compute_dtype, use_fast_accum=True,
)
if self.bias is not None:
out = out + self.bias
return out.reshape(*orig_shape[:-1], self.out_features)
def extra_repr(self):
return "in=%d, out=%d, fp8+_scaled_mm" % (self.in_features, self.out_features)
def quantize_transformer_fp8(transformer, compute_dtype=torch.bfloat16, verbose=True):
    """Replace every ``nn.Linear`` inside *transformer* with an ``FP8Linear``, in place.

    Each original layer is released as soon as its replacement is installed.
    Only layer names are collected up front, so no list keeps the original
    modules alive. Peak memory therefore stays near "quantized model + one
    layer's float32 copy" rather than holding every original weight until
    the end.

    Args:
        transformer: root ``nn.Module`` to patch. The root itself is never
            replaced, even if it is an ``nn.Linear``, because it has no parent
            to attach the replacement to.
        compute_dtype: forwarded to ``FP8Linear``.
        verbose: print progress every 200 layers and a final summary.

    Returns:
        dict with ``num_layers``, ``original_mb``, ``quantized_mb`` and
        ``saved_mb``, plus ``compression`` as a string such as ``"2.0x"``.
        Sizes count weight bytes only; biases are not included.
    """
    import gc

    original_bytes = 0
    quantized_bytes = 0
    count = 0
    # Keep names only: holding (name, module) tuples would pin every original
    # Linear (and its weight) in memory until the function returns.
    linear_names = [
        name for name, module in transformer.named_modules()
        if name and isinstance(module, nn.Linear)
    ]
    total = len(linear_names)
    if verbose:
        print("[native_fp8] Quantizing %d Linear layers (_scaled_mm v3)..." % total)
    for i, name in enumerate(linear_names):
        parent_name, _, child_name = name.rpartition(".")
        parent = transformer.get_submodule(parent_name) if parent_name else transformer
        module = getattr(parent, child_name)
        original_bytes += module.weight.numel() * module.weight.element_size()
        fp8_module = FP8Linear(module, compute_dtype=compute_dtype)
        # 1 byte per FP8 element + the 4-byte float32 scale.
        quantized_bytes += fp8_module.fp8_weight.numel() * 1 + 4
        setattr(parent, child_name, fp8_module)
        # Drop the last local references so the original weight is freed now.
        del module, fp8_module
        count += 1
        if verbose and (i + 1) % 200 == 0:
            print("[native_fp8] %d/%d layers done" % (i + 1, total))
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    ratio = original_bytes / max(quantized_bytes, 1)
    stats = {
        "num_layers": count,
        "original_mb": original_bytes / 1024**2,
        "quantized_mb": quantized_bytes / 1024**2,
        "saved_mb": (original_bytes - quantized_bytes) / 1024**2,
        "compression": "%.1fx" % ratio,
    }
    if verbose:
        print("[native_fp8] Done: %d layers, %.0f MB -> %.0f MB (saved %.0f MB, %.1fx)" % (
            stats["num_layers"], stats["original_mb"], stats["quantized_mb"],
            stats["saved_mb"], ratio,
        ))
    return stats
def enable_fp8(infer_state_override=True):
    """Disable hyvideo's angelslim FP8 GEMM flag on the global infer state.

    Despite its name, this function quantizes nothing. It sets
    ``use_fp8_gemm = False`` on the object returned by
    ``hyvideo.commons.infer_state.get_infer_state()``. The intent is
    presumably that only this module's native FP8 path is active; confirm
    against hyvideo. Quantization itself is done by
    ``quantize_transformer_fp8``.

    The call is best effort. If ``hyvideo`` cannot be imported or the state
    cannot be updated, a message is printed and ``False`` is returned
    instead of raising.

    Args:
        infer_state_override: when False, do nothing.

    Returns:
        True if ``use_fp8_gemm`` was set to False, otherwise False.
    """
    if not infer_state_override:
        return False
    try:
        from hyvideo.commons.infer_state import get_infer_state
        state = get_infer_state()
        state.use_fp8_gemm = False
    except Exception as exc:  # best effort: hyvideo may be absent or uninitialised
        print("[native_fp8] Could not disable angelslim FP8 (%s: %s)" % (type(exc).__name__, exc))
        return False
    print("[native_fp8] Disabled angelslim FP8 (use_fp8_gemm=False)")
    return True
if __name__ == "__main__":
    import sys
    import time

    print("=" * 60)
    print("native_fp8_patch self-test (v3 _scaled_mm)")
    print("=" * 60)
    # torch._scaled_mm's FP8 GEMM needs a CUDA GPU with compute capability >= 8.9;
    # skip cleanly instead of failing deep inside the kernel call.
    if not torch.cuda.is_available():
        print("  SKIP: CUDA not available")
        sys.exit(0)
    capability = torch.cuda.get_device_capability()
    if capability < (8, 9):
        print("  SKIP: GPU compute capability %d.%d < 8.9" % capability)
        sys.exit(0)
    torch.manual_seed(42)
    device = "cuda"

    class MockTransformer(nn.Module):
        """Three-Linear MLP standing in for a transformer block (biased and unbiased layers)."""

        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(3072, 3072, bias=False)
            self.linear2 = nn.Linear(3072, 12288, bias=True)
            self.linear3 = nn.Linear(12288, 3072, bias=True)

        def forward(self, x):
            return self.linear3(F.gelu(self.linear2(self.linear1(x))))

    model = MockTransformer().to(torch.bfloat16).to(device)
    x = torch.randn(1, 100, 3072, dtype=torch.bfloat16, device=device)
    with torch.no_grad():
        y_bf16 = model(x)
    quantize_transformer_fp8(model)
    n_iters = 200
    with torch.no_grad():
        _ = model(x)  # warm-up before timing
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(n_iters):
            y_fp8 = model(x)
        torch.cuda.synchronize()
    t = (time.perf_counter() - t0) / n_iters * 1000
    # Compare in float32 so the similarity metric itself is not bf16-limited.
    cos = F.cosine_similarity(
        y_bf16.float().reshape(-1, 3072), y_fp8.float().reshape(-1, 3072), dim=-1
    ).mean().item()
    print(" Cosine: %.6f Time: %.2f ms %s" % (cos, t, "PASS" if cos > 0.995 else "WARN"))