File size: 22,510 Bytes
0707b22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 | import os
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from tqdm import tqdm
from musubi_tuner.utils.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
from musubi_tuner.utils.device_utils import clean_memory_on_device
def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
"""
Calculate the maximum representable value in FP8 format.
Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). Only supports E4M3 and E5M2 with sign bit.
Args:
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
sign_bits (int): Number of sign bits (0 or 1)
Returns:
float: Maximum value representable in FP8 format
"""
assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
if exp_bits == 4 and mantissa_bits == 3 and sign_bits == 1:
return torch.finfo(torch.float8_e4m3fn).max
elif exp_bits == 5 and mantissa_bits == 2 and sign_bits == 1:
return torch.finfo(torch.float8_e5m2).max
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits} with sign_bits={sign_bits}")
# The following is a manual calculation method (wrong implementation for E5M2), kept for reference.
"""
# Calculate exponent bias
bias = 2 ** (exp_bits - 1) - 1
# Calculate maximum mantissa value
mantissa_max = 1.0
for i in range(mantissa_bits - 1):
mantissa_max += 2 ** -(i + 1)
# Calculate maximum value
max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
return max_value
"""
def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value):
"""
Quantize a tensor to FP8 format using PyTorch's native FP8 dtype support.
Args:
tensor (torch.Tensor): Tensor to quantize
scale (float or torch.Tensor): Scale factor
fp8_dtype (torch.dtype): Target FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
max_value (float): Maximum representable value in FP8
min_value (float): Minimum representable value in FP8
Returns:
torch.Tensor: Quantized tensor in FP8 format
"""
tensor = tensor.to(torch.float32) # ensure tensor is in float32 for division
# Create scaled tensor
tensor = torch.div(tensor, scale).nan_to_num_(0.0) # handle NaN values, equivalent to nonzero_mask in previous function
# Clamp tensor to range
tensor = tensor.clamp_(min=min_value, max=max_value)
# Convert to FP8 dtype
tensor = tensor.to(fp8_dtype)
return tensor
def optimize_state_dict_with_fp8(
state_dict: dict,
calc_device: Union[str, torch.device],
target_layer_keys: Optional[list[str]] = None,
exclude_layer_keys: Optional[list[str]] = None,
exp_bits: int = 4,
mantissa_bits: int = 3,
move_to_device: bool = False,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
):
"""
Optimize Linear layer weights in a model's state dict to FP8 format. The state dict is modified in-place.
This function is a static version of load_safetensors_with_fp8_optimization without loading from files.
Args:
state_dict (dict): State dict to optimize, replaced in-place
calc_device (str): Device to quantize tensors on
target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
exclude_layer_keys (list, optional): Layer key patterns to exclude
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
move_to_device (bool): Move optimized tensors to the calculating device
Returns:
dict: FP8 optimized state dict
"""
if exp_bits == 4 and mantissa_bits == 3:
fp8_dtype = torch.float8_e4m3fn
elif exp_bits == 5 and mantissa_bits == 2:
fp8_dtype = torch.float8_e5m2
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
# Calculate FP8 max value
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
min_value = -max_value # this function supports only signed FP8
# Create optimized state dict
optimized_count = 0
# Enumerate tarket keys
target_state_dict_keys = []
for key in state_dict.keys():
# Check if it's a weight key and matches target patterns
is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
is_target = is_target and not is_excluded
if is_target and isinstance(state_dict[key], torch.Tensor):
target_state_dict_keys.append(key)
# Process each key
for key in tqdm(target_state_dict_keys):
value = state_dict[key]
# Save original device and dtype
original_device = value.device
original_dtype = value.dtype
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size)
# Add to state dict using original key for weight and new key for scale
fp8_key = key # Maintain original key
scale_key = key.replace(".weight", ".scale_weight")
if not move_to_device:
quantized_weight = quantized_weight.to(original_device)
# keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model.
scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)
state_dict[fp8_key] = quantized_weight
state_dict[scale_key] = scale_tensor
optimized_count += 1
if calc_device is not None: # optimized_count % 10 == 0 and
# free memory on calculation device
clean_memory_on_device(calc_device)
logger.info(f"Number of optimized Linear layers: {optimized_count}")
return state_dict
def quantize_weight(
key: str,
tensor: torch.Tensor,
fp8_dtype: torch.dtype,
max_value: float,
min_value: float,
quantization_mode: str = "block",
block_size: int = 64,
):
original_shape = tensor.shape
# Determine quantization mode
if quantization_mode == "block":
if tensor.ndim != 2:
quantization_mode = "tensor" # fallback to per-tensor
else:
out_features, in_features = tensor.shape
if in_features % block_size != 0:
quantization_mode = "channel" # fallback to per-channel
logger.warning(
f"Layer {key} with shape {tensor.shape} is not divisible by block_size {block_size}, fallback to per-channel quantization."
)
else:
num_blocks = in_features // block_size
tensor = tensor.contiguous().view(out_features, num_blocks, block_size) # [out, num_blocks, block_size]
elif quantization_mode == "channel":
if tensor.ndim != 2:
quantization_mode = "tensor" # fallback to per-tensor
# Calculate scale factor (per-tensor or per-output-channel with percentile or max)
# value shape is expected to be [out_features, in_features] for Linear weights
if quantization_mode == "channel" or quantization_mode == "block":
# row-wise percentile to avoid being dominated by outliers
# result shape: [out_features, 1] or [out_features, num_blocks, 1]
scale_dim = 1 if quantization_mode == "channel" else 2
abs_w = torch.abs(tensor)
# shape: [out_features, 1] or [out_features, num_blocks, 1]
row_max = torch.max(abs_w, dim=scale_dim, keepdim=True).values
scale = row_max / max_value
else:
# per-tensor
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# print(f"Optimizing {key} with scale: {scale}")
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
# Quantize weight to FP8 (scale can be scalar or [out,1], broadcasting works)
quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value)
# If block-wise, restore original shape
if quantization_mode == "block":
quantized_weight = quantized_weight.view(original_shape) # restore to original shape [out, in]
return quantized_weight, scale
def load_safetensors_with_fp8_optimization(
model_files: List[str],
calc_device: Union[str, torch.device],
target_layer_keys=None,
exclude_layer_keys=None,
exp_bits=4,
mantissa_bits=3,
move_to_device=False,
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
Args:
model_files (list[str]): List of model files to load
calc_device (str or torch.device): Device to quantize tensors on
target_layer_keys (list, optional): Layer key patterns to target for optimization (None for all Linear layers)
exclude_layer_keys (list, optional): Layer key patterns to exclude from optimization
exp_bits (int): Number of exponent bits
mantissa_bits (int): Number of mantissa bits
move_to_device (bool): Move optimized tensors to the calculating device
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns:
dict: FP8 optimized state dict
"""
if exp_bits == 4 and mantissa_bits == 3:
fp8_dtype = torch.float8_e4m3fn
elif exp_bits == 5 and mantissa_bits == 2:
fp8_dtype = torch.float8_e5m2
else:
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
# Calculate FP8 max value
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
min_value = -max_value # this function supports only signed FP8
# Define function to determine if a key is a target key. target means fp8 optimization, not for weight hook.
def is_target_key(key):
# Check if weight key matches target patterns and does not match exclude patterns
is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
return is_target and not is_excluded
# Create optimized state dict
optimized_count = 0
# Process each file
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
# Save original device
original_device = value.device # usually cpu
if weight_hook is not None:
# Apply weight hook if provided
value = weight_hook(key, value, keep_on_calc_device=(calc_device is not None))
if not is_target_key(key):
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
original_dtype = value.dtype
if original_dtype.itemsize == 1:
raise ValueError(
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
)
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
# Add to state dict using original key for weight and new key for scale
fp8_key = key # Maintain original key
scale_key = key.replace(".weight", ".scale_weight")
assert fp8_key != scale_key, "FP8 key and scale key must be different"
if not move_to_device:
quantized_weight = quantized_weight.to(original_device)
# keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model.
scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)
state_dict[fp8_key] = quantized_weight
state_dict[scale_key] = scale_tensor
optimized_count += 1
if calc_device is not None and optimized_count % 10 == 0:
# free memory on calculation device
clean_memory_on_device(calc_device)
logger.info(f"Number of optimized Linear layers: {optimized_count}")
return state_dict
def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
"""
Patched forward method for Linear layers with FP8 weights.
Args:
self: Linear layer instance
x (torch.Tensor): Input tensor
use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor.
Returns:
torch.Tensor: Result of linear transformation
"""
if use_scaled_mm:
# **not tested**
# _scaled_mm only works for per-tensor scale for now (per-channel scale does not work in certain cases)
if self.scale_weight.ndim != 1:
raise ValueError("scaled_mm only supports per-tensor scale_weight for now.")
input_dtype = x.dtype
original_weight_dtype = self.scale_weight.dtype
target_dtype = self.weight.dtype
# assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
if max_value is None:
# no input quantization
scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
else:
# calculate scale factor for input tensor
scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
# quantize input tensor to FP8: this seems to consume a lot of memory
fp8_max_value = torch.finfo(target_dtype).max
fp8_min_value = torch.finfo(target_dtype).min
x = quantize_fp8(x, scale_x, target_dtype, fp8_max_value, fp8_min_value)
original_shape = x.shape
x = x.reshape(-1, x.shape[-1]).to(target_dtype)
weight = self.weight.t()
scale_weight = self.scale_weight.to(torch.float32)
if self.bias is not None:
# float32 is not supported with bias in scaled_mm
o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:
# Dequantize the weight
original_dtype = self.scale_weight.dtype
if self.scale_weight.ndim < 3:
# per-tensor or per-channel quantization, we can broadcast
dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
else:
# block-wise quantization, need to reshape weight to match scale shape for broadcasting
out_features, num_blocks, _ = self.scale_weight.shape
dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_features, num_blocks, -1)
dequantized_weight = dequantized_weight * self.scale_weight
dequantized_weight = dequantized_weight.view(self.weight.shape)
# Perform linear transformation
if self.bias is not None:
output = F.linear(x, dequantized_weight, self.bias)
else:
output = F.linear(x, dequantized_weight)
return output
def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
"""
Apply monkey patching to a model using FP8 optimized state dict.
Args:
model (nn.Module): Model instance to patch
optimized_state_dict (dict): FP8 optimized state dict
use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
Returns:
nn.Module: The patched model (same instance, modified in-place)
"""
# # Calculate FP8 float8_e5m2 max value
# max_value = calculate_fp8_maxval(5, 2)
max_value = None # do not quantize input tensor
# Find all scale keys to identify FP8-optimized layers
scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
# Enumerate patched layers
patched_module_paths = set()
scale_shape_info = {}
for scale_key in scale_keys:
# Extract module path from scale key (remove .scale_weight)
module_path = scale_key.rsplit(".scale_weight", 1)[0]
patched_module_paths.add(module_path)
# Store scale shape information
scale_shape_info[module_path] = optimized_state_dict[scale_key].shape
patched_count = 0
# Apply monkey patch to each layer with FP8 weights
for name, module in model.named_modules():
# Check if this module has a corresponding scale_weight
has_scale = name in patched_module_paths
# Apply patch if it's a Linear layer with FP8 scale
if isinstance(module, nn.Linear) and has_scale:
# register the scale_weight as a buffer to load the state_dict
# module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
scale_shape = scale_shape_info[name]
module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype))
# Create a new forward method with the patched version.
def new_forward(self, x):
return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
# Bind method to module
module.forward = new_forward.__get__(module, type(module))
patched_count += 1
logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
return model
# Example usage
def example_usage():
# Small test model
class TestModel(nn.Module):
def __init__(self):
super().__init__()
fc1 = nn.Linear(768, 3072)
act1 = nn.GELU()
fc2 = nn.Linear(3072, 768)
act2 = nn.GELU()
fc3 = nn.Linear(768, 768)
# Set layer names for testing
self.single_blocks = nn.ModuleList([fc1, act1, fc2, act2, fc3])
self.fc4 = nn.Linear(768, 128)
def forward(self, x):
for layer in self.single_blocks:
x = layer(x)
x = self.fc4(x)
return x
# Instantiate model
test_model = TestModel()
test_model.to(torch.float16) # convert to FP16 for testing
# Test input tensor
test_input = torch.randn(1, 768, dtype=torch.float16)
# Calculate output before optimization
with torch.no_grad():
original_output = test_model(test_input)
print("original output", original_output[0, :5])
# Get state dict
state_dict = test_model.state_dict()
# Apply FP8 optimization to state dict
cuda_device = torch.device("cuda")
optimized_state_dict = optimize_state_dict_with_fp8(state_dict, cuda_device, ["single_blocks"], ["2"])
# Apply monkey patching to the model
optimized_model = TestModel() # re-instantiate model
optimized_model.to(torch.float16) # convert to FP16 for testing
apply_fp8_monkey_patch(optimized_model, optimized_state_dict)
# Load optimized state dict
optimized_model.load_state_dict(optimized_state_dict, strict=True, assign=True) # assign=True to load buffer
# Calculate output after optimization
with torch.no_grad():
optimized_output = optimized_model(test_input)
print("optimized output", optimized_output[0, :5])
# Compare accuracy
error = torch.mean(torch.abs(original_output - optimized_output))
print(f"Mean absolute error: {error.item()}")
# Check memory usage
original_params = sum(p.nelement() * p.element_size() for p in test_model.parameters()) / (1024 * 1024)
print(f"Model parameter memory: {original_params:.2f} MB")
optimized_params = sum(p.nelement() * p.element_size() for p in optimized_model.parameters()) / (1024 * 1024)
print(f"Optimized model parameter memory: {optimized_params:.2f} MB")
return test_model
if __name__ == "__main__":
example_usage()
|