import os
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F

import logging

from tqdm import tqdm

from musubi_tuner.utils.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

from musubi_tuner.utils.device_utils import clean_memory_on_device


def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
    """
    Calculate the maximum representable value in FP8 format.
    Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). Only supports E4M3 and E5M2 with sign bit.

    Args:
        exp_bits (int): Number of exponent bits
        mantissa_bits (int): Number of mantissa bits
        sign_bits (int): Number of sign bits (0 or 1)

    Returns:
        float: Maximum value representable in FP8 format
    """
    assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
    if exp_bits == 4 and mantissa_bits == 3 and sign_bits == 1:
        return torch.finfo(torch.float8_e4m3fn).max
    elif exp_bits == 5 and mantissa_bits == 2 and sign_bits == 1:
        return torch.finfo(torch.float8_e5m2).max
    else:
        raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits} with sign_bits={sign_bits}")
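
# For reference, the two supported configurations resolve to PyTorch's native FP8 limits:
# calculate_fp8_maxval(4, 3) == torch.finfo(torch.float8_e4m3fn).max == 448.0, and
# calculate_fp8_maxval(5, 2) == torch.finfo(torch.float8_e5m2).max == 57344.0.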

# The previous manual computation of the FP8 maximum is kept below for reference;
# the torch.finfo limits above are used instead.
#
#     # Calculate exponent bias
#     bias = 2 ** (exp_bits - 1) - 1
#
#     # Calculate maximum mantissa value
#     mantissa_max = 1.0
#     for i in range(mantissa_bits - 1):
#         mantissa_max += 2 ** -(i + 1)
#
#     # Calculate maximum value
#     max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
#
#     return max_value


def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value):
    """
    Quantize a tensor to FP8 format using PyTorch's native FP8 dtype support.

    Args:
        tensor (torch.Tensor): Tensor to quantize
        scale (float or torch.Tensor): Scale factor
        fp8_dtype (torch.dtype): Target FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
        max_value (float): Maximum representable value in FP8
        min_value (float): Minimum representable value in FP8

    Returns:
        torch.Tensor: Quantized tensor in FP8 format
    """
    tensor = tensor.to(torch.float32)

    # Scale the tensor by the (per-tensor, per-channel, or per-block) scale factor
    tensor = torch.div(tensor, scale).nan_to_num_(0.0)

    # Clamp to the representable FP8 range
    tensor = tensor.clamp_(min=min_value, max=max_value)

    # Convert to the target FP8 dtype
    tensor = tensor.to(fp8_dtype)

    return tensor
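
# Minimal sketch of the round trip implemented by quantize_fp8 (values are illustrative,
# not part of this module's API):
#
#   w = torch.randn(4, 8)
#   max_v = calculate_fp8_maxval(4, 3)                      # 448.0 for E4M3
#   scale = torch.clamp(w.abs().max() / max_v, min=1e-8)    # per-tensor scale
#   w_fp8 = quantize_fp8(w, scale, torch.float8_e4m3fn, max_v, -max_v)
#   w_restored = w_fp8.to(torch.float32) * scale            # approximation of the original weight
#
# This mirrors what quantize_weight() below does in "tensor" mode and how the patched
# forward dequantizes with scale_weight.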


def optimize_state_dict_with_fp8(
    state_dict: dict,
    calc_device: Union[str, torch.device],
    target_layer_keys: Optional[list[str]] = None,
    exclude_layer_keys: Optional[list[str]] = None,
    exp_bits: int = 4,
    mantissa_bits: int = 3,
    move_to_device: bool = False,
    quantization_mode: str = "block",
    block_size: Optional[int] = 64,
):
    """
    Optimize Linear layer weights in a model's state dict to FP8 format. The state dict is modified in-place.
    This function is a static version of load_safetensors_with_fp8_optimization without loading from files.

    Args:
        state_dict (dict): State dict to optimize, modified in-place
        calc_device (str or torch.device): Device to quantize tensors on
        target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
        exclude_layer_keys (list, optional): Layer key patterns to exclude
        exp_bits (int): Number of exponent bits
        mantissa_bits (int): Number of mantissa bits
        move_to_device (bool): Move optimized tensors to the calculating device
        quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
        block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")

    Returns:
        dict: FP8 optimized state dict
    """
    if exp_bits == 4 and mantissa_bits == 3:
        fp8_dtype = torch.float8_e4m3fn
    elif exp_bits == 5 and mantissa_bits == 2:
        fp8_dtype = torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")

    # Calculate the representable FP8 range
    max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
    min_value = -max_value

    optimized_count = 0

    # Collect the keys of the weights to optimize
    target_state_dict_keys = []
    for key in state_dict.keys():
        # Target keys match one of the target patterns, end with ".weight", and are not excluded
        is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
        is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
        is_target = is_target and not is_excluded

        if is_target and isinstance(state_dict[key], torch.Tensor):
            target_state_dict_keys.append(key)

    # Quantize each target weight
    for key in tqdm(target_state_dict_keys):
        value = state_dict[key]

        # Remember the original device and dtype
        original_device = value.device
        original_dtype = value.dtype

        # Move to the calculation device for quantization
        if calc_device is not None:
            value = value.to(calc_device)

        quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size)

        # The quantized weight keeps the original key; the scale gets a companion key
        fp8_key = key
        scale_key = key.replace(".weight", ".scale_weight")

        if not move_to_device:
            quantized_weight = quantized_weight.to(original_device)

        # Store the scale in the original dtype on the same device as the quantized weight
        scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)

        state_dict[fp8_key] = quantized_weight
        state_dict[scale_key] = scale_tensor

        optimized_count += 1

        if calc_device is not None:
            # Free memory on the calculation device
            clean_memory_on_device(calc_device)

    logger.info(f"Number of optimized Linear layers: {optimized_count}")
    return state_dict
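
# A minimal sketch of the resulting key layout (module names are hypothetical):
#
#   before: {"blocks.0.linear.weight": fp16 tensor of shape (out, in), ...}
#   after:  {"blocks.0.linear.weight":       float8 tensor of shape (out, in),
#            "blocks.0.linear.scale_weight": scale tensor (see quantize_weight below), ...}
#
# Non-target keys are left untouched; only ".weight" entries matching the target patterns
# are replaced and gain a companion ".scale_weight" entry.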


def quantize_weight(
    key: str,
    tensor: torch.Tensor,
    fp8_dtype: torch.dtype,
    max_value: float,
    min_value: float,
    quantization_mode: str = "block",
    block_size: int = 64,
):
    """
    Quantize a single weight tensor to FP8 and return the quantized weight and its scale.

    The scale is computed per tensor, per output channel, or per block of `block_size`
    input features, depending on `quantization_mode`. Block mode falls back to channel
    mode when the input dimension is not divisible by `block_size`; block and channel
    modes fall back to tensor mode for non-2D tensors. The scale is returned as float32.
    """
    original_shape = tensor.shape

    # Decide the effective quantization mode based on the tensor shape
    if quantization_mode == "block":
        if tensor.ndim != 2:
            quantization_mode = "tensor"
        else:
            out_features, in_features = tensor.shape
            if in_features % block_size != 0:
                quantization_mode = "channel"
                logger.warning(
                    f"Layer {key} with shape {tensor.shape} is not divisible by block_size {block_size}, fallback to per-channel quantization."
                )
            else:
                num_blocks = in_features // block_size
                tensor = tensor.contiguous().view(out_features, num_blocks, block_size)
    elif quantization_mode == "channel":
        if tensor.ndim != 2:
            quantization_mode = "tensor"

    # Calculate the scale factor
    if quantization_mode == "channel" or quantization_mode == "block":
        # Per-channel: reduce over the input dimension; per-block: reduce over the block dimension
        scale_dim = 1 if quantization_mode == "channel" else 2
        abs_w = torch.abs(tensor)

        # Maximum absolute value per channel / per block
        row_max = torch.max(abs_w, dim=scale_dim, keepdim=True).values
        scale = row_max / max_value
    else:
        # Per-tensor quantization
        tensor_max = torch.max(torch.abs(tensor).view(-1))
        scale = tensor_max / max_value

    # Avoid division by zero
    scale = torch.clamp(scale, min=1e-8)
    scale = scale.to(torch.float32)

    # Quantize the weight to FP8
    quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value)

    # Restore the original 2D shape for block-wise quantization
    if quantization_mode == "block":
        quantized_weight = quantized_weight.view(original_shape)

    return quantized_weight, scale
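
# Scale shapes returned by quantize_weight for a 2D weight of shape (out_features, in_features),
# assuming in_features is divisible by block_size (otherwise block mode falls back to channel mode):
#
#   "tensor":  scalar (0-dim) scale
#   "channel": shape (out_features, 1)
#   "block":   shape (out_features, in_features // block_size, 1)
#
# The block-wise shape is what fp8_linear_forward_patch below detects via scale_weight.ndim.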


def load_safetensors_with_fp8_optimization(
    model_files: List[str],
    calc_device: Union[str, torch.device],
    target_layer_keys=None,
    exclude_layer_keys=None,
    exp_bits=4,
    mantissa_bits=3,
    move_to_device=False,
    weight_hook=None,
    quantization_mode: str = "block",
    block_size: Optional[int] = 64,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
    """
    Load weight tensors from safetensors files and build an FP8-optimized state dict.
    An optional weight_hook can modify each tensor before optimization (e.g., to merge LoRA weights).

    Args:
        model_files (list[str]): List of model files to load
        calc_device (str or torch.device): Device to quantize tensors on
        target_layer_keys (list, optional): Layer key patterns to target for optimization (None for all Linear layers)
        exclude_layer_keys (list, optional): Layer key patterns to exclude from optimization
        exp_bits (int): Number of exponent bits
        mantissa_bits (int): Number of mantissa bits
        move_to_device (bool): Move optimized tensors to the calculating device
        weight_hook (callable, optional): Function to apply to each weight tensor before optimization
        quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
        block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
        disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
        weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading

    Returns:
        dict: FP8 optimized state dict
    """
    if exp_bits == 4 and mantissa_bits == 3:
        fp8_dtype = torch.float8_e4m3fn
    elif exp_bits == 5 and mantissa_bits == 2:
        fp8_dtype = torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")

    # Calculate the representable FP8 range
    max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
    min_value = -max_value

    # A key is optimized if it matches a target pattern, ends with ".weight", and is not excluded
    def is_target_key(key):
        is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
        is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
        return is_target and not is_excluded

    optimized_count = 0

    state_dict = {}
    for model_file in model_files:
        with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
            f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f

            keys = f.keys()
            for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
                value = f.get_tensor(key)

                # Remember the device the tensor was loaded onto
                original_device = value.device

                if weight_hook is not None:
                    # Apply the hook (e.g., LoRA merging) before any optimization
                    value = weight_hook(key, value, keep_on_calc_device=(calc_device is not None))

                if not is_target_key(key):
                    target_device = calc_device if (calc_device is not None and move_to_device) else original_device
                    value = value.to(target_device)
                    state_dict[key] = value
                    continue

                # Move to the calculation device for quantization
                if calc_device is not None:
                    value = value.to(calc_device)

                original_dtype = value.dtype
                if original_dtype.itemsize == 1:
                    raise ValueError(
                        f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
                        + f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
                    )
                quantized_weight, scale_tensor = quantize_weight(
                    key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
                )

                # The quantized weight keeps the original key; the scale gets a companion key
                fp8_key = key
                scale_key = key.replace(".weight", ".scale_weight")
                assert fp8_key != scale_key, "FP8 key and scale key must be different"

                if not move_to_device:
                    quantized_weight = quantized_weight.to(original_device)

                # Store the scale in the original dtype on the same device as the quantized weight
                scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device)

                state_dict[fp8_key] = quantized_weight
                state_dict[scale_key] = scale_tensor

                optimized_count += 1

                if calc_device is not None and optimized_count % 10 == 0:
                    # Periodically free memory on the calculation device
                    clean_memory_on_device(calc_device)

    logger.info(f"Number of optimized Linear layers: {optimized_count}")
    return state_dict
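
# Illustrative sketch of a typical call (file names and exclusion patterns are hypothetical):
#
#   sd = load_safetensors_with_fp8_optimization(
#       ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
#       calc_device="cuda",
#       target_layer_keys=["single_blocks"],   # quantize only matching Linear weights
#       exclude_layer_keys=["norm"],           # hypothetical exclusion pattern
#   )
#
# The returned dict holds FP8 ".weight" tensors plus matching ".scale_weight" tensors,
# ready for apply_fp8_monkey_patch() and load_state_dict(..., assign=True) as in example_usage() below.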


def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
    """
    Patched forward method for Linear layers with FP8 weights.

    Args:
        self: Linear layer instance
        x (torch.Tensor): Input tensor
        use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
        max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied to the input tensor.

    Returns:
        torch.Tensor: Result of linear transformation
    """
    if use_scaled_mm:
        # torch._scaled_mm path: cast the input to FP8 and multiply directly in FP8
        if self.scale_weight.ndim != 1:
            raise ValueError("scaled_mm only supports per-tensor scale_weight for now.")

        input_dtype = x.dtype
        original_weight_dtype = self.scale_weight.dtype
        target_dtype = self.weight.dtype

        if max_value is None:
            # No input quantization: use a scale of 1.0 for the input
            scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
        else:
            # Calculate a per-tensor scale for the input
            scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)

        # Quantize the input tensor to the weight's FP8 dtype
        fp8_max_value = torch.finfo(target_dtype).max
        fp8_min_value = torch.finfo(target_dtype).min
        x = quantize_fp8(x, scale_x, target_dtype, fp8_max_value, fp8_min_value)

        original_shape = x.shape
        x = x.reshape(-1, x.shape[-1]).to(target_dtype)

        weight = self.weight.t()
        scale_weight = self.scale_weight.to(torch.float32)

        if self.bias is not None:
            o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
        else:
            o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)

        o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
        return o.to(input_dtype)

    else:
        # Dequantize the weight and fall back to a standard F.linear call
        original_dtype = self.scale_weight.dtype
        if self.scale_weight.ndim < 3:
            # Per-tensor or per-channel scale: broadcast directly
            dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
        else:
            # Block-wise scale: reshape the weight into blocks, scale, then restore the original shape
            out_features, num_blocks, _ = self.scale_weight.shape
            dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_features, num_blocks, -1)
            dequantized_weight = dequantized_weight * self.scale_weight
            dequantized_weight = dequantized_weight.view(self.weight.shape)

        # Standard linear layer with the dequantized weight
        if self.bias is not None:
            output = F.linear(x, dequantized_weight, self.bias)
        else:
            output = F.linear(x, dequantized_weight)

        return output
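
# In effect, the fallback path above computes the usual linear transform with a dequantized weight,
#
#   y = x @ (W_fp8.to(dtype) * scale_weight).T + bias
#
# while the scaled_mm path keeps both operands in FP8 and passes scale_a / scale_b to
# torch._scaled_mm instead of materializing the dequantized weight.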


def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
    """
    Apply monkey patching to a model using FP8 optimized state dict.

    Args:
        model (nn.Module): Model instance to patch
        optimized_state_dict (dict): FP8 optimized state dict
        use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)

    Returns:
        nn.Module: The patched model (same instance, modified in-place)
    """
    # max_value=None means the input tensor is not scaled in the patched forward
    max_value = None

    # Find all scale keys in the optimized state dict
    scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]

    # Enumerate the module paths that have a scale and remember each scale's shape
    patched_module_paths = set()
    scale_shape_info = {}
    for scale_key in scale_keys:
        # Extract the module path from the scale key
        module_path = scale_key.rsplit(".scale_weight", 1)[0]
        patched_module_paths.add(module_path)

        scale_shape_info[module_path] = optimized_state_dict[scale_key].shape

    patched_count = 0

    # Patch each Linear module that has a corresponding scale
    for name, module in model.named_modules():
        has_scale = name in patched_module_paths

        if isinstance(module, nn.Linear) and has_scale:
            # Register the scale_weight buffer with the correct shape; the actual values
            # are filled in later by load_state_dict
            scale_shape = scale_shape_info[name]
            module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype))

            # Create a new forward method that uses the FP8 weight and scale
            def new_forward(self, x):
                return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)

            # Bind the new forward method to the module instance
            module.forward = new_forward.__get__(module, type(module))

            patched_count += 1

    logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
    return model


def example_usage():
    # Small test model with a few Linear layers
    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            fc1 = nn.Linear(768, 3072)
            act1 = nn.GELU()
            fc2 = nn.Linear(3072, 768)
            act2 = nn.GELU()
            fc3 = nn.Linear(768, 768)

            # Register the layers in a ModuleList so their keys contain "single_blocks"
            self.single_blocks = nn.ModuleList([fc1, act1, fc2, act2, fc3])

            self.fc4 = nn.Linear(768, 128)

        def forward(self, x):
            for layer in self.single_blocks:
                x = layer(x)
            x = self.fc4(x)
            return x

    # Create a test model in fp16
    test_model = TestModel()
    test_model.to(torch.float16)

    # Test input tensor
    test_input = torch.randn(1, 768, dtype=torch.float16)

    # Calculate the output before optimization
    with torch.no_grad():
        original_output = test_model(test_input)
    print("original output", original_output[0, :5])

    # Get the state dict
    state_dict = test_model.state_dict()

    # Apply FP8 optimization to the state dict (keys containing "2", i.e. single_blocks.2, are excluded)
    cuda_device = torch.device("cuda")
    optimized_state_dict = optimize_state_dict_with_fp8(state_dict, cuda_device, ["single_blocks"], ["2"])

    # Create a new model and apply the FP8 monkey patch before loading the optimized weights
    optimized_model = TestModel()
    optimized_model.to(torch.float16)
    apply_fp8_monkey_patch(optimized_model, optimized_state_dict)

    # Load the optimized state dict
    optimized_model.load_state_dict(optimized_state_dict, strict=True, assign=True)

    # Calculate the output after optimization
    with torch.no_grad():
        optimized_output = optimized_model(test_input)
    print("optimized output", optimized_output[0, :5])

    # Compare the outputs
    error = torch.mean(torch.abs(original_output - optimized_output))
    print(f"Mean absolute error: {error.item()}")

    # Compare the parameter memory footprints
    original_params = sum(p.nelement() * p.element_size() for p in test_model.parameters()) / (1024 * 1024)
    print(f"Model parameter memory: {original_params:.2f} MB")
    optimized_params = sum(p.nelement() * p.element_size() for p in optimized_model.parameters()) / (1024 * 1024)
    print(f"Optimized model parameter memory: {optimized_params:.2f} MB")

    return test_model


if __name__ == "__main__":
    example_usage()