#!/usr/bin/env python3
"""
PyTorch 原生 FP8 权重量化 (v3 _scaled_mm 版)
=============================================
v3: torch._scaled_mm 原生 FP8 matmul (SM89 硬件加速)。
关键: b 必须用 .t() (非 .T.contiguous()), 否则 CUBLAS_STATUS_NOT_SUPPORTED。
输入动态量化用 bf16 精度计算 amax,避免 float32 临时拷贝的显存开销。
用法:
import native_fp8_patch
native_fp8_patch.quantize_transformer_fp8(pipe.transformer)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class FP8Linear(nn.Module):
"""FP8 weight-only quantized Linear. v3: _scaled_mm + bf16 amax."""
def __init__(self, original_linear: nn.Linear, compute_dtype=torch.bfloat16):
super().__init__()
weight = original_linear.weight.data.float()
amax = weight.abs().max()
scale = (amax / 448.0).clamp(min=1e-12)
fp8_weight = (weight / scale).to(torch.float8_e4m3fn)
self.register_buffer("fp8_weight", fp8_weight)
self.register_buffer("weight_scale", scale.view(()))
if original_linear.bias is not None:
self.register_buffer("bias", original_linear.bias.data.to(compute_dtype))
else:
self.bias = None
self.in_features = original_linear.in_features
self.out_features = original_linear.out_features
self.compute_dtype = compute_dtype
del weight, fp8_weight
@property
def weight(self):
return self.fp8_weight.to(self.compute_dtype) * self.weight_scale.to(self.compute_dtype)
def forward(self, x):
orig_shape = x.shape
x_2d = x.reshape(-1, self.in_features)
# bf16 精度算 amax — 避免 .float() 拷贝省显存
x_amax = x_2d.detach().abs().amax()
x_scale = (x_amax.float() / 448.0).clamp(min=1e-12).view(())
x_fp8 = (x_2d / x_scale).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_fp8, self.fp8_weight.t(),
scale_a=x_scale, scale_b=self.weight_scale,
out_dtype=self.compute_dtype, use_fast_accum=True,
)
if self.bias is not None:
out = out + self.bias
return out.reshape(*orig_shape[:-1], self.out_features)
def extra_repr(self):
return "in=%d, out=%d, fp8+_scaled_mm" % (self.in_features, self.out_features)
def quantize_transformer_fp8(transformer, compute_dtype=torch.bfloat16, verbose=True):
    """Replace every ``nn.Linear`` inside ``transformer`` with ``FP8Linear``, in place.

    Args:
        transformer: Module whose submodules are scanned with
            ``named_modules()``; it is modified in place.
        compute_dtype: Passed through to each ``FP8Linear``.
        verbose: When true, progress and a summary are printed.

    Returns:
        A dict with ``num_layers`` (int), ``original_mb``, ``quantized_mb``
        and ``saved_mb`` (floats, bytes divided by 1024**2) and
        ``compression`` (a string such as ``"2.0x"``). Only weight bytes are
        counted, plus 4 bytes per layer for the scale.
    """
    import gc

    # Collect names only. Keeping the modules themselves in a list would hold
    # a reference to every original weight until this function returns, so
    # nothing could be freed while the FP8 copies are being created.
    linear_names = [
        name
        for name, module in transformer.named_modules()
        if isinstance(module, nn.Linear)
    ]

    # named_modules() reports the root under the empty name. A root that is
    # itself a Linear has no parent attribute to swap, so it is left alone.
    if "" in linear_names:
        linear_names.remove("")
        if verbose:
            print("[native_fp8] Skipping root module: it is an nn.Linear and cannot be replaced in place")

    total = len(linear_names)
    if verbose:
        print("[native_fp8] Quantizing %d Linear layers (_scaled_mm v3)..." % total)

    original_bytes = 0
    quantized_bytes = 0
    count = 0
    for i, name in enumerate(linear_names):
        *parent_parts, leaf = name.split(".")
        parent = transformer
        for part in parent_parts:
            parent = getattr(parent, part)

        module = getattr(parent, leaf, None)
        # The path may no longer lead to a Linear (for example when an
        # enclosing Linear subclass was already replaced); nothing to do then.
        if not isinstance(module, nn.Linear):
            continue

        original_bytes += module.weight.numel() * module.weight.element_size()
        fp8_module = FP8Linear(module, compute_dtype=compute_dtype)
        # One byte per FP8 element plus four bytes for the float32 scale.
        quantized_bytes += fp8_module.fp8_weight.numel() * 1 + 4

        setattr(parent, leaf, fp8_module)
        # Drop the last local reference so the original weight can be
        # released before the next layer is quantized.
        del module
        count += 1

        if verbose and (i + 1) % 200 == 0:
            print("[native_fp8] %d/%d layers done" % (i + 1, total))

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    ratio = original_bytes / max(quantized_bytes, 1)
    stats = {
        "num_layers": count,
        "original_mb": original_bytes / 1024**2,
        "quantized_mb": quantized_bytes / 1024**2,
        "saved_mb": (original_bytes - quantized_bytes) / 1024**2,
        "compression": "%.1fx" % ratio,
    }
    if verbose:
        print("[native_fp8] Done: %d layers, %.0f MB -> %.0f MB (saved %.0f MB, %.1fx)" % (
            stats["num_layers"], stats["original_mb"], stats["quantized_mb"],
            stats["saved_mb"], ratio
        ))
    return stats
def enable_fp8(infer_state_override=True):
    """Best-effort switch-off of the angelslim FP8 GEMM flag.

    When ``infer_state_override`` is true, this imports ``get_infer_state``
    from ``hyvideo.commons.infer_state`` and sets ``use_fp8_gemm = False`` on
    the object it returns. Any failure (for example the package not being
    installed) is reported on stdout and otherwise ignored, so the call
    never raises.

    Args:
        infer_state_override: When false the function does nothing.

    Returns:
        None.
    """
    if not infer_state_override:
        return

    try:
        from hyvideo.commons.infer_state import get_infer_state
        get_infer_state().use_fp8_gemm = False
    except Exception as exc:
        # Deliberately best-effort: report why the flag was left untouched
        # instead of failing silently.
        print("[native_fp8] Could not disable angelslim FP8 (%s: %s)" % (type(exc).__name__, exc))
    else:
        print("[native_fp8] Disabled angelslim FP8 (use_fp8_gemm=False)")
if __name__ == "__main__":
    # Self-test: quantize a small three-layer MLP, compare its output with
    # the bf16 reference and time the quantized forward pass.
    import time

    print("=" * 60)
    print("native_fp8_patch self-test (v3 _scaled_mm)")
    print("=" * 60)

    # The test tensors are placed on "cuda"; stop with a clear message
    # rather than an opaque error when no such device exists.
    if not torch.cuda.is_available():
        raise SystemExit("[native_fp8] self-test requires a CUDA device, none is available")

    torch.manual_seed(42)
    device = "cuda"

    class MockTransformer(nn.Module):
        """Three stacked Linear layers with a GELU after the second."""

        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(3072, 3072, bias=False)
            self.linear2 = nn.Linear(3072, 12288, bias=True)
            self.linear3 = nn.Linear(12288, 3072, bias=True)

        def forward(self, x):
            return self.linear3(F.gelu(self.linear2(self.linear1(x))))

    model = MockTransformer().to(torch.bfloat16).to(device)
    x = torch.randn(1, 100, 3072, dtype=torch.bfloat16, device=device)

    # Reference output before any layer is replaced.
    with torch.no_grad():
        y_bf16 = model(x)

    stats = quantize_transformer_fp8(model)

    with torch.no_grad():
        # Warm-up call so one-time setup is not included in the timing.
        _ = model(x)
        torch.cuda.synchronize()
        # perf_counter is monotonic and high resolution, unlike time.time().
        t0 = time.perf_counter()
        for _ in range(200):
            y_fp8 = model(x)
        # Wait for queued GPU work before reading the clock.
        torch.cuda.synchronize()
        t = (time.perf_counter() - t0) / 200 * 1000

    cos = F.cosine_similarity(y_bf16.reshape(-1, 3072), y_fp8.reshape(-1, 3072), dim=-1).mean().item()
    print(" Cosine: %.6f Time: %.2f ms %s" % (cos, t, "PASS" if cos > 0.995 else "WARN"))
|