Fabrice-TIERCELIN committed on
Commit
1480dd1
·
verified ·
1 Parent(s): 957014a

Delete utils

Browse files
Files changed (2) hide show
  1. utils/fp8_optimization_utils.py +0 -277
  2. utils/lora_utils.py +0 -234
utils/fp8_optimization_utils.py DELETED
@@ -1,277 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from tqdm import tqdm
6
-
7
-
8
def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
    """
    Calculate the maximum finite value representable in an FP8 format.
    Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign).

    Two encodings are distinguished:

    * E5M2 reserves its highest exponent code for inf/NaN, so the largest
      finite value uses the second-highest exponent code with a full mantissa
      (57344, the same as ``torch.finfo(torch.float8_e5m2).max``).
    * Every other layout is treated like E4M3FN: the highest exponent code is
      usable and only the all-ones mantissa is excluded (448 for E4M3).

    Args:
        exp_bits (int): Number of exponent bits
        mantissa_bits (int): Number of mantissa bits
        sign_bits (int): Number of sign bits (0 or 1)

    Returns:
        float: Maximum finite value representable in the FP8 format

    Raises:
        AssertionError: If the bit counts do not add up to 8
    """
    assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"

    # Exponent bias
    bias = 2 ** (exp_bits - 1) - 1

    if exp_bits == 5 and mantissa_bits == 2:
        # IEEE-style: the top exponent code encodes inf/NaN, every mantissa pattern is usable
        usable_mantissa_bits = mantissa_bits
        max_exponent = 2**exp_bits - 2 - bias
    else:
        # "fn"-style: the top exponent code is finite, the all-ones mantissa is excluded
        usable_mantissa_bits = mantissa_bits - 1
        max_exponent = 2**exp_bits - 1 - bias

    # Largest mantissa: 1 + 1/2 + 1/4 + ...
    mantissa_max = 1.0
    for i in range(usable_mantissa_bits):
        mantissa_max += 2 ** -(i + 1)

    return mantissa_max * (2**max_exponent)
35
-
36
-
37
def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None):
    """
    Simulate FP8 quantization of a tensor.

    The tensor is divided by ``scale``, clamped to ``[min_value, max_value]`` and
    rounded onto the FP8 grid. The result stays in the dtype of ``tensor / scale``
    (it is not cast to an FP8 dtype here) and is still in the scaled domain, i.e.
    it has not been multiplied back by ``scale``.

    Args:
        tensor (torch.Tensor): Tensor to quantize
        scale (float or torch.Tensor): Scale factor the tensor is divided by
        exp_bits (int): Number of exponent bits
        mantissa_bits (int): Number of mantissa bits
        sign_bits (int): Number of sign bits
        max_value (float, optional): Upper clamp bound. If None, it is computed
            with ``calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)``.
        min_value (float, optional): Lower clamp bound. If None, it defaults to
            ``-max_value`` for signed formats and ``0.0`` for unsigned formats.

    Returns:
        tuple: (quantized_tensor, scale_factor)
    """
    # Create scaled tensor
    scaled_tensor = tensor / scale

    # Calculate FP8 parameters
    bias = 2 ** (exp_bits - 1) - 1

    if max_value is None:
        max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)
    if min_value is None:
        # Derive the lower bound independently of max_value, so that passing only
        # max_value still clamps the negative side instead of leaving it unbounded.
        min_value = -max_value if sign_bits > 0 else 0.0

    # Clamp tensor to range
    clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value)

    # Quantization process
    abs_values = torch.abs(clamped_tensor)
    nonzero_mask = abs_values > 0

    # Biased exponent of every non-zero element (zeros keep 0 and are lifted by the clamp below)
    log_scales = torch.zeros_like(clamped_tensor)
    if nonzero_mask.any():
        log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach()

    # Limit log scales and calculate quantization factor (the step size for each element)
    log_scales = torch.clamp(log_scales, min=1.0)
    quant_factor = 2.0 ** (log_scales - mantissa_bits - bias)

    # Quantize and dequantize: round to the nearest multiple of the step size
    quantized = torch.round(clamped_tensor / quant_factor) * quant_factor

    return quantized, scale
82
-
83
-
84
def optimize_state_dict_with_fp8(
    state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False
):
    """
    Optimize weights in a model's state dict to FP8 format.

    Every tensor entry whose key ends with ".weight", matches ``target_layer_keys``
    (if given) and does not match ``exclude_layer_keys`` is replaced by its FP8
    version, and a per-tensor scale is stored next to it under the same key with
    the ".weight" suffix replaced by ".scale_weight".

    Args:
        state_dict (dict): State dict to optimize, replaced in-place
        calc_device (str): Device to quantize tensors on (None to stay on each tensor's device)
        target_layer_keys (list, optional): Layer key patterns to target (None for all ".weight" keys)
        exclude_layer_keys (list, optional): Layer key patterns to exclude
        exp_bits (int): Number of exponent bits
        mantissa_bits (int): Number of mantissa bits
        move_to_device (bool): Move optimized tensors to the calculating device

    Returns:
        dict: FP8 optimized state dict (the same object as ``state_dict``)

    Raises:
        ValueError: If the format is neither E4M3 nor E5M2
    """
    if exp_bits == 4 and mantissa_bits == 3:
        fp8_dtype = torch.float8_e4m3fn
    elif exp_bits == 5 and mantissa_bits == 2:
        fp8_dtype = torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")

    # Calculate FP8 max value
    max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
    min_value = -max_value  # this function supports only signed FP8

    weight_suffix = ".weight"

    def is_target_key(key):
        # Only weight entries that match the include patterns and none of the exclude patterns
        if not key.endswith(weight_suffix):
            return False
        if target_layer_keys is not None and not any(pattern in key for pattern in target_layer_keys):
            return False
        if exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys):
            return False
        return isinstance(state_dict[key], torch.Tensor)

    # Enumerate target keys first, because scale entries are added to state_dict while processing
    target_state_dict_keys = [key for key in state_dict.keys() if is_target_key(key)]

    # The CUDA allocator cache only needs flushing when the calculation runs on a CUDA device
    calc_on_cuda = calc_device is not None and torch.device(calc_device).type == "cuda"

    optimized_count = 0

    # Process each key
    for key in tqdm(target_state_dict_keys):
        value = state_dict[key]

        # Save original device and dtype
        original_device = value.device
        original_dtype = value.dtype

        # Move to calculation device
        if calc_device is not None:
            value = value.to(calc_device)

        # Calculate scale factor
        scale = torch.max(torch.abs(value.flatten())) / max_value
        # An all-zero weight would give scale 0 and turn every element into NaN (0 / 0);
        # fall back to a neutral scale of 1 so the weight stays zero.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))

        # Quantize weight to FP8
        quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)

        # Keep the original key for the weight; swap only the trailing ".weight" for the scale key
        # (str.replace would also rewrite any earlier ".weight" inside the module path).
        scale_key = key[: -len(weight_suffix)] + ".scale_weight"

        quantized_weight = quantized_weight.to(fp8_dtype)

        if not move_to_device:
            quantized_weight = quantized_weight.to(original_device)

        scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)

        state_dict[key] = quantized_weight
        state_dict[scale_key] = scale_tensor

        optimized_count += 1

        if calc_on_cuda:
            # free memory on calculation device
            torch.cuda.empty_cache()

    print(f"Number of optimized Linear layers: {optimized_count}")
    return state_dict
168
-
169
-
170
def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
    """
    Forward pass for a Linear layer whose weight is stored in FP8 with a ``scale_weight`` buffer.

    Args:
        self: Linear layer instance (must carry a ``scale_weight`` attribute)
        x (torch.Tensor): Input tensor
        use_scaled_mm (bool): Run the matmul through ``torch._scaled_mm`` instead of
            dequantizing the weight first; requires SM 8.9+ (RTX 40 series)
        max_value (float): Maximum value for FP8 quantization of the input. If None,
            the input is not quantized (only used when ``use_scaled_mm`` is True).

    Returns:
        torch.Tensor: Result of linear transformation
    """
    if not use_scaled_mm:
        # Dequantize path: rebuild the weight in the dtype of the scale, then do a plain linear
        compute_dtype = self.scale_weight.dtype
        restored_weight = self.weight.to(compute_dtype) * self.scale_weight
        if self.bias is None:
            return F.linear(x, restored_weight)
        return F.linear(x, restored_weight, self.bias)

    # scaled_mm path
    in_dtype = x.dtype
    scale_dtype = self.scale_weight.dtype
    fp8_input_dtype = torch.float8_e5m2
    assert self.weight.dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported"
    assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"

    if max_value is not None:
        # per-tensor scale for the input, then simulated E5M2 quantization
        # (original note: this seems to consume a lot of memory)
        scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
        x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value)
    else:
        # no input quantization: neutral scale
        scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)

    batch_size, seq_len = x.shape[0], x.shape[1]
    x_2d = x.reshape(-1, x.shape[2]).to(fp8_input_dtype)

    weight_t = self.weight.t()
    scale_w = self.scale_weight.to(torch.float32)

    if self.bias is None:
        out = torch._scaled_mm(x_2d, weight_t, out_dtype=in_dtype, scale_a=scale_x, scale_b=scale_w)
    else:
        # original note: float32 is not supported with bias in scaled_mm
        out = torch._scaled_mm(x_2d, weight_t, out_dtype=scale_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_w)

    return out.reshape(batch_size, seq_len, -1).to(in_dtype)
227
-
228
-
229
def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
    """
    Patch the Linear layers of a model that have an FP8 scale in the optimized state dict.

    A Linear module is patched when the state dict contains "<module path>.scale_weight".
    Each patched module gets a ``scale_weight`` buffer (placeholder value 1.0 in the
    weight's dtype, so the real scale can be loaded from the state dict) and a
    ``forward`` that delegates to ``fp8_linear_forward_patch``.

    Args:
        model (nn.Module): Model instance to patch
        optimized_state_dict (dict): FP8 optimized state dict
        use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)

    Returns:
        nn.Module: The patched model (same instance, modified in-place)
    """
    scale_suffix = ".scale_weight"

    # Input tensors are not quantized, so no FP8 max value is handed to the forward patch
    max_value = None

    # Module paths that own a scale entry (scale key with the trailing suffix removed)
    fp8_module_paths = {
        state_key.rsplit(scale_suffix, 1)[0] for state_key in optimized_state_dict.keys() if state_key.endswith(scale_suffix)
    }

    def new_forward(self, x):
        return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)

    patched_count = 0
    for name, module in model.named_modules():
        # Only Linear layers with a matching scale entry are patched
        if not isinstance(module, nn.Linear) or name not in fp8_module_paths:
            continue

        # register the scale_weight as a buffer to load the state_dict
        module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))

        # Bind the patched forward to this module instance
        module.forward = new_forward.__get__(module, type(module))

        patched_count += 1

    print(f"Number of monkey-patched Linear layers: {patched_count}")
    return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/lora_utils.py DELETED
@@ -1,234 +0,0 @@
1
- import os
2
- import torch
3
- from safetensors.torch import load_file
4
- from tqdm import tqdm
5
-
6
-
7
def merge_lora_to_state_dict(
    state_dict: dict[str, torch.Tensor], lora_file: str, multiplier: float, device: torch.device
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights from a safetensors file into the state dict of a model.

    The file format is detected from its keys:

    * first key starts with "lora_unet_": merged with ``merge_musubi_tuner``
    * some key contains "lora_A" and some key's first dot-separated segment is
      "diffusion_model" or "transformer": merged with ``merge_diffusion_pipe_or_something``

    Any other file (including one with no tensors) is reported as unrecognized
    and ``state_dict`` is returned unchanged.

    Args:
        state_dict: Model state dict to merge into
        lora_file: Path of the LoRA safetensors file
        multiplier: Strength applied to the LoRA delta
        device: Device used for the merge calculation

    Returns:
        The state dict returned by the format-specific merge function, or
        ``state_dict`` itself when the format is not recognized
    """
    lora_sd = load_file(lora_file)

    # Check the format of the LoRA file
    keys = list(lora_sd.keys())

    # Guard the first-key lookup: an empty file has no keys[0] and falls through to "not recognized"
    if keys and keys[0].startswith("lora_unet_"):
        print(f"Musubi Tuner LoRA detected")
        return merge_musubi_tuner(lora_sd, state_dict, multiplier, device)

    transformer_prefixes = ["diffusion_model", "transformer"]  # to ignore Text Encoder modules
    has_lora_a = any("lora_A" in key for key in keys)
    has_transformer_prefix = any(key.split(".")[0] in transformer_prefixes for key in keys)

    if has_lora_a and has_transformer_prefix:
        print(f"Diffusion-pipe (?) LoRA detected")
        return merge_diffusion_pipe_or_something(lora_sd, state_dict, "lora_unet_", multiplier, device)

    print(f"LoRA file format not recognized: {os.path.basename(lora_file)}")
    return state_dict
40
-
41
-
42
def merge_diffusion_pipe_or_something(
    lora_sd: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor], prefix: str, multiplier: float, device: torch.device
) -> dict[str, torch.Tensor]:
    """
    Convert "lora_A"/"lora_B" style LoRA weights to the Musubi Tuner key format and merge them.
    Copy from Musubi Tuner repo.

    Input format:  {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
    Output format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}

    The input has no alpha, so an alpha equal to the rank (first dimension of the
    lora_down weight) is added for every module. Keys whose first dot-separated
    segment is not "diffusion_model" or "transformer", and keys without any dot,
    are reported and skipped.

    Args:
        lora_sd: LoRA state dict in lora_A/lora_B format
        state_dict: Model state dict to merge into
        prefix: Prefix for the converted LoRA module names
        multiplier: Strength applied to the LoRA delta
        device: Device used for the merge calculation

    Returns:
        The state dict returned by ``merge_musubi_tuner``
    """
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in lora_sd.items():
        # partition never fails to unpack, unlike split(".", 1) on a key without a dot
        diffusers_prefix, separator, key_body = key.partition(".")
        if not separator or (diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer"):
            print(f"unexpected key: {key} in diffusers format")
            continue

        new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        new_weights_sd[new_key] = weight

        lora_name = new_key.split(".")[0]  # before first dot
        if lora_name not in lora_dims and "lora_down" in new_key:
            lora_dims[lora_name] = weight.shape[0]

    # add alpha with rank
    for lora_name, dim in lora_dims.items():
        new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)

    return merge_musubi_tuner(new_weights_sd, state_dict, multiplier, device)
74
-
75
-
76
def merge_musubi_tuner(
    lora_sd: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor], multiplier: float, device: torch.device
) -> dict[str, torch.Tensor]:
    """
    Merge Musubi Tuner format LoRA weights into the state dict of a model.

    LoRA keys look like "lora_unet_<module path with dots replaced by underscores>.lora_down.weight"
    (plus ".lora_up.weight" and an optional ".alpha"). If any key mentions
    "double_blocks" or "single_blocks" the LoRA is first converted with
    ``convert_hunyuan_to_framepack``. Each matched weight W is replaced by
    ``W + multiplier * (up @ down) * (alpha / rank)`` and written back with its
    original device and dtype.

    Args:
        lora_sd: LoRA state dict
        state_dict: Model state dict, modified in-place
        multiplier: Strength applied to the LoRA delta
        device: Device used for the merge calculation

    Returns:
        ``state_dict`` with merged weights
    """
    # Check LoRA is for FramePack or for HunyuanVideo
    is_hunyuan = any("double_blocks" in key or "single_blocks" in key for key in lora_sd.keys())
    if is_hunyuan:
        print("HunyuanVideo LoRA detected, converting to FramePack format")
        lora_sd = convert_hunyuan_to_framepack(lora_sd)

    # Merge LoRA weights into the state dict
    print(f"Merging LoRA weights into state dict. multiplier: {multiplier}")

    # Map LoRA module names to state dict keys (the first key wins on a name collision)
    name_to_original_key = {}
    for key in state_dict.keys():
        if key.endswith(".weight"):
            lora_name = key.rsplit(".", 1)[0]  # remove trailing ".weight"
            lora_name = "lora_unet_" + lora_name.replace(".", "_")
            name_to_original_key.setdefault(lora_name, key)

    # Merge LoRA weights
    keys = [k for k in lora_sd.keys() if "lora_down" in k]
    for key in tqdm(keys, desc="Merging LoRA weights"):
        up_key = key.replace("lora_down", "lora_up")
        alpha_key = key[: key.index("lora_down")] + "alpha"

        # find original key for this lora
        module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
        if module_name not in name_to_original_key:
            print(f"No module found for LoRA weight: {key}")
            continue

        # A down weight without its up counterpart cannot be merged; skip it like an unmatched module
        if up_key not in lora_sd:
            print(f"No lora_up weight found for LoRA weight: {key}")
            continue

        original_key = name_to_original_key[module_name]

        down_weight = lora_sd[key]
        up_weight = lora_sd[up_key]

        dim = down_weight.size()[0]
        alpha = lora_sd.get(alpha_key, dim)  # without alpha the scale is 1
        scale = alpha / dim

        weight = state_dict[original_key]
        original_device = weight.device
        original_dtype = weight.dtype
        if original_device != device:
            weight = weight.to(device)  # to make calculation faster

        down_weight = down_weight.to(device)
        up_weight = up_weight.to(device)

        # W <- W + U * D
        if len(weight.size()) == 2:
            # linear
            if len(up_weight.size()) == 4:  # use linear projection mismatch
                up_weight = up_weight.squeeze(3).squeeze(2)
                down_weight = down_weight.squeeze(3).squeeze(2)
            weight = weight + multiplier * (up_weight @ down_weight) * scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                weight
                + multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = weight + multiplier * conved * scale

        # Move back to the original device and dtype: adding a LoRA delta of a wider dtype
        # promotes the sum, which would otherwise change the dtype stored in the state dict.
        weight = weight.to(device=original_device, dtype=original_dtype)
        state_dict[original_key] = weight

    return state_dict
158
-
159
-
160
def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Rename HunyuanVideo LoRA keys to FramePack names and split fused projections.

    Only keys containing "double_blocks" or "single_blocks" are converted; all
    others are reported and dropped. Fused QKV (double blocks) and QKVM (single
    blocks) entries are expanded into one entry per projection: down weights
    and alphas are shared between the projections, up weights are sliced along
    dim 0 in chunks of 3072 rows, with the last projection taking the remainder.
    """
    double_block_renames = (
        ("double_blocks", "transformer_blocks"),
        ("img_mod_linear", "norm1_linear"),
        ("img_attn_qkv", "attn_to_QKV"),  # split later
        ("img_attn_proj", "attn_to_out_0"),
        ("img_mlp_fc1", "ff_net_0_proj"),
        ("img_mlp_fc2", "ff_net_2"),
        ("txt_mod_linear", "norm1_context_linear"),
        ("txt_attn_qkv", "attn_add_QKV_proj"),  # split later
        ("txt_attn_proj", "attn_to_add_out"),
        ("txt_mlp_fc1", "ff_context_net_0_proj"),
        ("txt_mlp_fc2", "ff_context_net_2"),
    )
    single_block_renames = (
        ("single_blocks", "single_transformer_blocks"),
        ("linear1", "attn_to_QKVM"),  # split later
        ("linear2", "proj_out"),
        ("modulation_linear", "norm_linear"),
    )
    chunk = 3072  # rows per Q/K/V slice of a fused up weight

    converted = {}
    for key, weight in lora_sd.items():
        if "double_blocks" in key:
            renames = double_block_renames
        elif "single_blocks" in key:
            renames = single_block_renames
        else:
            print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
            continue

        # Apply the renames in order
        for old_name, new_name in renames:
            key = key.replace(old_name, new_name)

        if "QKVM" in key:
            fused_label = "QKVM"
            split_keys = [
                key.replace("QKVM", "q"),
                key.replace("QKVM", "k"),
                key.replace("QKVM", "v"),
                key.replace("attn_to_QKVM", "proj_mlp"),
            ]
            expected_up_rows = 21504  # 3 * 3072 for Q, K, V plus 12288 for the MLP part
        elif "QKV" in key:
            fused_label = "QKV"
            split_keys = [key.replace("QKV", "q"), key.replace("QKV", "k"), key.replace("QKV", "v")]
            expected_up_rows = 3072 * 3
        else:
            # no split needed
            converted[key] = weight
            continue

        if "_down" in key or "alpha" in key:
            # the down weight (or alpha) is shared by every projection
            assert "alpha" in key or weight.size(1) == 3072, f"{fused_label} weight size mismatch: {key}. {weight.size()}"
            for split_key in split_keys:
                converted[split_key] = weight
        elif "_up" in key:
            # slice the fused up weight; the last projection takes everything that remains
            assert weight.size(0) == expected_up_rows, f"{fused_label} weight size mismatch: {key}. {weight.size()}"
            last_index = len(split_keys) - 1
            for index, split_key in enumerate(split_keys):
                start = chunk * index
                if index == last_index:
                    converted[split_key] = weight[start:]
                else:
                    converted[split_key] = weight[start : start + chunk]
        else:
            print(f"Unsupported module name: {key}")
            continue

    return converted