|
|
| """
|
| Convert a Flux model from Diffusers (folder or single-file) into the original
|
| single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.
|
|
|
| Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
|
| Output : /path/to/flux1-your-model.safetensors (transformer only)
|
|
|
| Usage:
|
| python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
|
| python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
|
| # optional quantization:
|
| # --fp8 (float8_e4m3fn, simple)
|
| # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors)
|
| """
|
|
|
| import argparse
|
| import json
|
| from pathlib import Path
|
| from collections import OrderedDict
|
|
|
| import torch
|
| from safetensors import safe_open
|
| import safetensors.torch
|
| from tqdm import tqdm
|
|
|
|
|
def parse_args():
    """Parse the converter's command-line arguments.

    Returns:
        argparse.Namespace with ``diffusers_path``, ``output_path``, ``fp8``
        and ``fp8_scaled``.

    ``--fp8`` and ``--fp8-scaled`` select two different output encodings, so
    they are declared mutually exclusive; passing both is a usage error
    instead of silently producing a file that mixes the two.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("diffusers_path", type=str,
                    help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
    ap.add_argument("output_path", type=str,
                    help="Output .safetensors path for the Flux transformer.")
    quant = ap.add_mutually_exclusive_group()
    quant.add_argument("--fp8", action="store_true",
                       help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
    quant.add_argument("--fp8-scaled", action="store_true",
                       help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
    return ap.parse_args()
|
|
|
|
|
|
|
# Mapping from target (original Flux) key templates to Diffusers source keys.
#
# * In a key, "()" is a placeholder that the converter replaces with the
#   block index.
# * A value is either a flat list of source keys, or a list of candidate
#   lists. Candidates are tried in order and the first one whose source keys
#   are all present is used.
# * When a candidate names several source keys, their tensors are merged into
#   the single target tensor.
# * Source keys of per-block entries are relative to the block prefix
#   ("transformer_blocks.N." / "single_transformer_blocks.N.").
DIFFUSERS_MAP = {
    # Timestep embedder.
    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
    "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
    "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
    "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],

    # Text (vector) embedder.
    "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
    "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
    "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
    "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],

    # Guidance embedder.
    "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
    "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
    "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
    "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],

    # Input projections.
    "txt_in.weight": ["context_embedder.weight"],
    "txt_in.bias": ["context_embedder.bias"],
    "img_in.weight": ["x_embedder.weight"],
    "img_in.bias": ["x_embedder.bias"],

    # ---- Double (dual-stream) blocks ----
    # Modulation layers.
    "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
    "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
    "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
    "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],

    # Fused QKV: either separate q/k/v projections (first candidate) or an
    # already fused projection (second candidate; presumably the layout of
    # pre-fused / quantized checkpoints - TODO confirm).
    "double_blocks.().img_attn.qkv.weight": [
        ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
        ["qkv_proj.weight"],
    ],
    "double_blocks.().img_attn.qkv.bias": [
        ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
        ["qkv_proj.bias"],
    ],
    "double_blocks.().txt_attn.qkv.weight": [
        ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
        ["qkv_proj_context.weight"],
    ],
    "double_blocks.().txt_attn.qkv.bias": [
        ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
        ["qkv_proj_context.bias"],
    ],

    # Query/key norm scales.
    "double_blocks.().img_attn.norm.query_norm.scale": [
        ["attn.norm_q.weight"],
        ["norm_q.weight"],
    ],
    "double_blocks.().img_attn.norm.key_norm.scale": [
        ["attn.norm_k.weight"],
        ["norm_k.weight"],
    ],
    "double_blocks.().txt_attn.norm.query_norm.scale": [
        ["attn.norm_added_q.weight"],
        ["norm_added_q.weight"],
    ],
    "double_blocks.().txt_attn.norm.key_norm.scale": [
        ["attn.norm_added_k.weight"],
        ["norm_added_k.weight"],
    ],

    # Image-stream MLP.
    "double_blocks.().img_mlp.0.weight": [
        ["ff.net.0.proj.weight"],
        ["mlp_fc1.weight"],
    ],
    "double_blocks.().img_mlp.0.bias": [
        ["ff.net.0.proj.bias"],
        ["mlp_fc1.bias"],
    ],
    "double_blocks.().img_mlp.2.weight": [
        ["ff.net.2.weight"],
        ["mlp_fc2.weight"],
    ],
    "double_blocks.().img_mlp.2.bias": [
        ["ff.net.2.bias"],
        ["mlp_fc2.bias"],
    ],

    # Text-stream MLP.
    "double_blocks.().txt_mlp.0.weight": [
        ["ff_context.net.0.proj.weight"],
        ["mlp_context_fc1.weight"],
    ],
    "double_blocks.().txt_mlp.0.bias": [
        ["ff_context.net.0.proj.bias"],
        ["mlp_context_fc1.bias"],
    ],
    "double_blocks.().txt_mlp.2.weight": [
        ["ff_context.net.2.weight"],
        ["mlp_context_fc2.weight"],
    ],
    "double_blocks.().txt_mlp.2.bias": [
        ["ff_context.net.2.bias"],
        ["mlp_context_fc2.bias"],
    ],

    # Attention output projections.
    "double_blocks.().img_attn.proj.weight": [
        ["attn.to_out.0.weight"],
        ["out_proj.weight"],
    ],
    "double_blocks.().img_attn.proj.bias": [
        ["attn.to_out.0.bias"],
        ["out_proj.bias"],
    ],
    "double_blocks.().txt_attn.proj.weight": [
        ["attn.to_add_out.weight"],
        ["out_proj_context.weight"],
    ],
    "double_blocks.().txt_attn.proj.bias": [
        ["attn.to_add_out.bias"],
        ["out_proj_context.bias"],
    ],

    # ---- Single (single-stream) blocks ----
    "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
    "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
    # linear1 is built from q, k, v and the MLP input projection.
    "single_blocks.().linear1.weight": [
        ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
        ["qkv_proj.weight", "mlp_fc1.weight"],
    ],
    "single_blocks.().linear1.bias": [
        ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
        ["qkv_proj.bias", "mlp_fc1.bias"],
    ],
    "single_blocks.().norm.query_norm.scale": [
        ["attn.norm_q.weight"],
        ["norm_q.weight"],
    ],
    "single_blocks.().norm.key_norm.scale": [
        ["attn.norm_k.weight"],
        ["norm_k.weight"],
    ],
    # linear2: when the second (two-source) candidate is used, the converter
    # merges the sources along the input-feature axis instead of the output
    # axis (it special-cases targets ending in ".linear2").
    "single_blocks.().linear2.weight": [
        ["proj_out.weight"],
        ["out_proj.weight", "mlp_fc2.weight"],
    ],
    "single_blocks.().linear2.bias": [
        ["proj_out.bias"],
        ["out_proj.bias", "mlp_fc2.bias"],
    ],

    # ---- Final layer ----
    "final_layer.linear.weight": ["proj_out.weight"],
    "final_layer.linear.bias": ["proj_out.bias"],

    # The converter additionally swaps the two halves of these tensors after
    # mapping (see the scale/shift swap at the end of the conversion).
    "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
    "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
|
|
|
# For every target base name (template key minus its last dotted component),
# the set of final components (with leading ".") that DIFFUSERS_MAP lists
# explicitly, e.g. "txt_in" -> {".weight", ".bias"}.
_TARGET_SUFFIXES = {}
for _tgt_key in DIFFUSERS_MAP:
    _base, _suffix = _tgt_key.rsplit(".", 1)
    if _base not in _TARGET_SUFFIXES:
        _TARGET_SUFFIXES[_base] = set()
    _TARGET_SUFFIXES[_base].add(f".{_suffix}")
|
|
|
|
|
| def _strip_prefix(key: str) -> str:
|
| return key[6:] if key.startswith("model.") else key
|
|
|
|
|
class StateDictSource:
    """
    Key-lookup facade over an in-memory state dict.

    Exposes ``has(key)``, ``get(key)`` and ``base_keys``. Lookups accept keys
    without the optional ``"model."`` prefix and find the stored key with or
    without it.
    """

    # Prefixes tried, in order, when resolving a requested key.
    POSSIBLE_PREFIXES = ["", "model."]

    def __init__(self, state_dict: dict):
        self._state_dict = state_dict
        # Snapshot of the key order at construction time.
        self._all_keys = list(state_dict.keys())

    def _resolve(self, want: str):
        """Return the stored key matching *want*, or None if there is none."""
        candidates = (prefix + want for prefix in self.POSSIBLE_PREFIXES)
        return next((name for name in candidates if name in self._state_dict), None)

    def has(self, want: str) -> bool:
        """Tell whether *want* resolves to a stored key."""
        return self._resolve(want) is not None

    def get(self, want: str) -> torch.Tensor:
        """Return the value stored for *want*; raise KeyError if it is missing."""
        real_key = self._resolve(want)
        if real_key is not None:
            return self._state_dict[real_key]
        raise KeyError(f"Missing key: {want}")

    @property
    def base_keys(self):
        """All keys, each with a leading ``"model."`` removed."""
        return [_strip_prefix(name) for name in self._all_keys]
|
|
|
|
|
def detect_diffusers_state_dict(state_dict: dict) -> bool:
    """Tell whether *state_dict* uses Diffusers-style Flux key names.

    Keys are compared after removing a leading ``"model."``. Any key that
    already uses the original Flux block names makes the answer False;
    otherwise the answer is True iff some key uses a Diffusers block name.
    """
    names = [_strip_prefix(name) for name in state_dict]
    flux_prefixes = ("double_blocks.", "single_blocks.")
    diffusers_prefixes = ("transformer_blocks.", "single_transformer_blocks.")
    for name in names:
        if name.startswith(flux_prefixes):
            return False
    for name in names:
        if name.startswith(diffusers_prefixes):
            return True
    return False
|
|
|
|
|
def convert_state_dict(state_dict: dict, *, verbose: bool = False) -> dict:
    """Convert an in-memory Diffusers Flux state dict to the original layout.

    The input is returned unchanged when it is not detected as a Diffusers
    state dict, or when the conversion produces an empty result.
    """
    if detect_diffusers_state_dict(state_dict):
        result = _convert_from_source(StateDictSource(state_dict), verbose=verbose)
        if result:
            return result
    return state_dict
|
|
|
|
|
| def _count_blocks(base_keys):
|
| num_dual = 0
|
| num_single = 0
|
| for key in base_keys:
|
| if key.startswith("transformer_blocks."):
|
| try:
|
| idx = int(key.split(".")[1])
|
| num_dual = max(num_dual, idx + 1)
|
| except Exception:
|
| pass
|
| elif key.startswith("single_transformer_blocks."):
|
| try:
|
| idx = int(key.split(".")[1])
|
| num_single = max(num_single, idx + 1)
|
| except Exception:
|
| pass
|
| return num_dual, num_single
|
|
|
|
|
| def _swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
|
| if vec is None or vec.ndim != 1 or vec.numel() % 2 != 0:
|
| return vec
|
| shift, scale = vec.chunk(2, dim=0)
|
| return torch.cat([scale, shift], dim=0)
|
|
|
|
|
| def _swap_scale_shift_matrix(mat: torch.Tensor) -> torch.Tensor:
|
| if mat is None or mat.ndim != 2 or mat.size(0) % 2 != 0:
|
| return mat
|
| shift, scale = mat.chunk(2, dim=0)
|
| return torch.cat([scale, shift], dim=0)
|
|
|
|
|
| def _collect_suffixes(base_keys, base):
|
| prefix = base + "."
|
| suffixes = set()
|
| for key in base_keys:
|
| if key.startswith(prefix):
|
| suffixes.add("." + key[len(prefix):])
|
| return suffixes
|
|
|
|
|
| def _normalize_suffix(suffix: str) -> str:
|
| if suffix == ".smooth":
|
| return ".smooth_factor"
|
| if suffix == ".smooth_orig":
|
| return ".smooth_factor_orig"
|
| if suffix == ".lora_down":
|
| return ".proj_down"
|
| if suffix == ".lora_up":
|
| return ".proj_up"
|
| return suffix
|
|
|
|
|
| def _match_dtype_device(ref: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
|
| if ref.dtype != other.dtype or ref.device != other.device:
|
| return other.to(device=ref.device, dtype=ref.dtype)
|
| return other
|
|
|
|
|
| def _concat(values, dim=0):
|
| if any(v is None for v in values):
|
| return None
|
| ref = values[0]
|
| if ref.ndim == 0:
|
| return ref
|
| merged = [ref] + [_match_dtype_device(ref, v) for v in values[1:]]
|
| return torch.cat(merged, dim=dim)
|
|
|
|
|
| def _maybe_load_nunchaku():
|
| try:
|
| from shared.qtypes import nunchaku_int4 as _nunchaku_int4
|
| except Exception:
|
| return None
|
| return _nunchaku_int4
|
|
|
|
|
| def _get_qweight_dims(qweight: torch.Tensor | None) -> tuple[int | None, int | None]:
|
| if not torch.is_tensor(qweight):
|
| return None, None
|
| if qweight.dtype == torch.int8:
|
| out_features = qweight.size(0)
|
| else:
|
| out_features = qweight.size(0) * 4
|
| in_features = qweight.size(1) * 2
|
| return out_features, in_features
|
|
|
|
|
| def _merge_packed_scales_out(values, qweights, group_size: int = 64):
|
| if any(v is None for v in values):
|
| return None
|
| nunchaku = _maybe_load_nunchaku()
|
| if nunchaku is None or any(q is None for q in qweights):
|
| return _concat(values, dim=1)
|
| unpacked = []
|
| out_total = 0
|
| in_features = None
|
| for value, qweight in zip(values, qweights):
|
| out_i, in_i = _get_qweight_dims(qweight)
|
| if out_i is None or in_i is None:
|
| return _concat(values, dim=1)
|
| if in_features is None:
|
| in_features = in_i
|
| if in_i != in_features:
|
| return _concat(values, dim=1)
|
| unpacked_i = nunchaku._unpack_nunchaku_wscales(value, out_i, in_i, group_size)
|
| if not torch.is_tensor(unpacked_i):
|
| return _concat(values, dim=1)
|
| unpacked.append(unpacked_i)
|
| out_total += out_i
|
| merged = [unpacked[0]] + [_match_dtype_device(unpacked[0], u) for u in unpacked[1:]]
|
| merged = torch.cat(merged, dim=1)
|
| return nunchaku._pack_nunchaku_wscales(merged, out_total, in_features, group_size)
|
|
|
|
|
| def _merge_packed_scales_in(values, qweights, group_size: int = 64):
|
| if any(v is None for v in values):
|
| return None
|
| nunchaku = _maybe_load_nunchaku()
|
| if nunchaku is None or any(q is None for q in qweights):
|
| return _concat(values, dim=0)
|
| out_a, in_a = _get_qweight_dims(qweights[0])
|
| out_b, in_b = _get_qweight_dims(qweights[1])
|
| if out_a is None or in_a is None or out_b is None or in_b is None:
|
| return _concat(values, dim=0)
|
| if out_a != out_b:
|
| return _concat(values, dim=0)
|
| unpack_a = nunchaku._unpack_nunchaku_wscales(values[0], out_a, in_a, group_size)
|
| unpack_b = nunchaku._unpack_nunchaku_wscales(values[1], out_b, in_b, group_size)
|
| if not torch.is_tensor(unpack_a) or not torch.is_tensor(unpack_b):
|
| return _concat(values, dim=0)
|
| if unpack_a.dtype != unpack_b.dtype or unpack_a.device != unpack_b.device:
|
| unpack_b = unpack_b.to(device=unpack_a.device, dtype=unpack_a.dtype)
|
| merged = torch.cat([unpack_a, unpack_b], dim=0)
|
| return nunchaku._pack_nunchaku_wscales(merged, out_a, in_a + in_b, group_size)
|
|
|
|
|
| def _pad_to_multiple(tensor: torch.Tensor | None, divisor: int | tuple[int, int]):
|
| if tensor is None:
|
| return None
|
| if isinstance(divisor, int):
|
| div0 = div1 = divisor
|
| else:
|
| div0, div1 = divisor
|
| height, width = tensor.shape
|
| new_h = ((height + div0 - 1) // div0) * div0
|
| new_w = ((width + div1 - 1) // div1) * div1
|
| if new_h == height and new_w == width:
|
| return tensor
|
| padded = torch.zeros((new_h, new_w), dtype=tensor.dtype, device=tensor.device)
|
| padded[:height, :width] = tensor
|
| return padded
|
|
|
|
|
| def _pack_lowrank_weight(weight: torch.Tensor | None, down: bool):
|
| if weight is None or weight.ndim != 2:
|
| return weight
|
| lane_n, lane_k = 1, 2
|
| n_pack_size, k_pack_size = 2, 2
|
| num_n_lanes, num_k_lanes = 8, 4
|
| frag_n = n_pack_size * num_n_lanes * lane_n
|
| frag_k = k_pack_size * num_k_lanes * lane_k
|
| weight = _pad_to_multiple(weight, (frag_n, frag_k))
|
| if weight is None:
|
| return None
|
| if down:
|
| rows, cols = weight.shape
|
| r_frags, c_frags = rows // frag_n, cols // frag_k
|
| weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
|
| else:
|
| cols, rows = weight.shape
|
| c_frags, r_frags = cols // frag_n, rows // frag_k
|
| weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
|
| weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
|
| weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
|
| return weight.view(cols, rows)
|
|
|
|
|
| def _pack_nunchaku_w4a4_weight(qvals, out_features, in_features):
|
| if qvals is None or qvals.ndim != 2:
|
| return qvals
|
| if qvals.dtype not in (torch.int8, torch.int16, torch.int32):
|
| qvals = qvals.to(torch.int32)
|
| if qvals.dtype != torch.int32:
|
| qvals = qvals.to(torch.int32)
|
| if qvals.shape != (out_features, in_features):
|
| return None
|
| mem_n = 128
|
| mem_k = 64
|
| num_k_unrolls = 2
|
| if out_features % mem_n != 0 or in_features % (mem_k * num_k_unrolls) != 0:
|
| return None
|
| n_pack_size = 2
|
| k_pack_size = 2
|
| num_n_lanes = 8
|
| num_k_lanes = 4
|
| reg_n = 1
|
| reg_k = 8
|
| num_n_packs = mem_n // (n_pack_size * num_n_lanes * reg_n)
|
| num_k_packs = mem_k // (k_pack_size * num_k_lanes * reg_k)
|
| n_tiles = out_features // mem_n
|
| k_tiles = in_features // mem_k
|
| weight = qvals.reshape(
|
| n_tiles,
|
| num_n_packs,
|
| n_pack_size,
|
| num_n_lanes,
|
| reg_n,
|
| k_tiles,
|
| num_k_packs,
|
| k_pack_size,
|
| num_k_lanes,
|
| reg_k,
|
| )
|
| weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
|
| weight = weight.bitwise_and_(0xF)
|
| shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
|
| weight = weight.bitwise_left_shift_(shifts)
|
| weight = weight.sum(dim=-1, dtype=torch.int32)
|
| return weight.view(dtype=torch.int8).view(out_features, -1)
|
|
|
|
|
| def _unpack_lowrank_weight(weight: torch.Tensor | None, down: bool):
|
| if weight is None:
|
| return None
|
| nunchaku = _maybe_load_nunchaku()
|
| if nunchaku is None:
|
| return weight
|
| return nunchaku._unpack_lowrank_weight(weight, down)
|
|
|
|
|
| def _block_diag_out(mats):
|
| if not mats:
|
| return None
|
| ref = mats[0]
|
| total_rows = sum(m.size(0) for m in mats)
|
| total_cols = sum(m.size(1) for m in mats)
|
| out = torch.zeros((total_rows, total_cols), dtype=ref.dtype, device=ref.device)
|
| row = 0
|
| col = 0
|
| for mat in mats:
|
| mat = _match_dtype_device(ref, mat)
|
| out[row : row + mat.size(0), col : col + mat.size(1)] = mat
|
| row += mat.size(0)
|
| col += mat.size(1)
|
| return out
|
|
|
|
|
| def _merge_lowrank_down(values):
|
| if any(v is None for v in values):
|
| return None
|
| unpacked = [_unpack_lowrank_weight(v, down=True) for v in values]
|
| ref = unpacked[0]
|
| merged = [ref] + [_match_dtype_device(ref, v) for v in unpacked[1:]]
|
| merged = torch.cat(merged, dim=0)
|
| return _pack_lowrank_weight(merged, down=True)
|
|
|
|
|
| def _merge_lowrank_up(values):
|
| if any(v is None for v in values):
|
| return None
|
| unpacked = [_unpack_lowrank_weight(v, down=False) for v in values]
|
| merged = _block_diag_out(unpacked)
|
| return _pack_lowrank_weight(merged, down=False)
|
|
|
|
|
| def _merge_lowrank_down_block_diag(values):
|
| if any(v is None for v in values):
|
| return None
|
| unpacked = [_unpack_lowrank_weight(v, down=True) for v in values]
|
| merged = _block_diag_out(unpacked)
|
| return _pack_lowrank_weight(merged, down=True)
|
|
|
|
|
| def _merge_lowrank_up_concat(values):
|
| if any(v is None for v in values):
|
| return None
|
| unpacked = [_unpack_lowrank_weight(v, down=False) for v in values]
|
| ref = unpacked[0]
|
| merged = [ref] + [_match_dtype_device(ref, v) for v in unpacked[1:]]
|
| merged = torch.cat(merged, dim=1)
|
| return _pack_lowrank_weight(merged, down=False)
|
|
|
|
|
| def _merge_qweight_in(values):
|
| if any(v is None for v in values):
|
| return None
|
| a, b = values
|
| out_a, in_a = _get_qweight_dims(a)
|
| out_b, in_b = _get_qweight_dims(b)
|
| if out_a is None or in_a is None or out_b is None or in_b is None:
|
| return _concat(values, dim=1)
|
| if out_a != out_b:
|
| return _concat(values, dim=1)
|
| nunchaku = _maybe_load_nunchaku()
|
| if nunchaku is None or a.dtype != torch.int8 or b.dtype != torch.int8:
|
| return _concat(values, dim=1)
|
| unpack_a = nunchaku._unpack_nunchaku_w4a4_weight(a, out_a, in_a)
|
| unpack_b = nunchaku._unpack_nunchaku_w4a4_weight(b, out_b, in_b)
|
| if not torch.is_tensor(unpack_a) or not torch.is_tensor(unpack_b):
|
| return _concat(values, dim=1)
|
| if unpack_a.dtype != torch.int32:
|
| unpack_a = unpack_a.to(torch.int32)
|
| if unpack_b.dtype != torch.int32:
|
| unpack_b = unpack_b.to(torch.int32)
|
| if unpack_a.dtype != unpack_b.dtype or unpack_a.device != unpack_b.device:
|
| unpack_b = unpack_b.to(device=unpack_a.device, dtype=unpack_a.dtype)
|
| merged = torch.cat([unpack_a, unpack_b], dim=1)
|
| packed = _pack_nunchaku_w4a4_weight(merged, out_a, in_a + in_b)
|
| if packed is None:
|
| return _concat(values, dim=1)
|
| return packed
|
|
|
|
|
| def _merge_multi(values, suffix, qweights):
|
| if any(v is None for v in values):
|
| return None
|
| if suffix in (".wscales", ".wzeros"):
|
| return _merge_packed_scales_out(values, qweights)
|
| if suffix == ".proj_down":
|
| return _merge_lowrank_down(values)
|
| if suffix == ".proj_up":
|
| return _merge_lowrank_up(values)
|
| if suffix in (".smooth_factor", ".smooth_factor_orig", ".input_scale", ".output_scale", ".scale_weight"):
|
| return values[0]
|
| return _concat(values, dim=0)
|
|
|
|
|
| def _merge_multi_in(values, suffix, qweights):
|
| if any(v is None for v in values):
|
| return None
|
| if suffix == ".qweight":
|
| return _merge_qweight_in(values)
|
| if suffix in (".wscales", ".wzeros"):
|
| return _merge_packed_scales_in(values, qweights)
|
| if suffix == ".proj_down":
|
| return _merge_lowrank_down_block_diag(values)
|
| if suffix == ".proj_up":
|
| return _merge_lowrank_up_concat(values)
|
| if suffix in (".smooth_factor", ".smooth_factor_orig"):
|
| return _concat(values, dim=0)
|
| if suffix == ".bias":
|
| ref = values[0]
|
| total = ref
|
| for val in values[1:]:
|
| total = total + _match_dtype_device(ref, val)
|
| return total
|
| if suffix in (".input_scale", ".output_scale", ".scale_weight"):
|
| return values[0]
|
| return _concat(values, dim=1)
|
|
|
|
|
def _convert_from_source(src, *, verbose: bool = False) -> dict:
    """Build an original-layout Flux state dict from a Diffusers key source.

    Args:
        src: object providing ``base_keys``, ``has(key)`` and ``get(key)``.
        verbose: when True, print the detected block counts.

    Returns:
        dict mapping original Flux keys to tensors. Entries of DIFFUSERS_MAP
        whose source tensors are not present are skipped, so the result can
        be empty.

    For every block (and once for the non-block entries) each DIFFUSERS_MAP
    entry is resolved by trying its candidates in order; the first candidate
    that yields at least one suffix to copy is used. Finally the two halves
    of the final-layer adaLN modulation weight/bias are swapped.
    """
    base_keys = src.base_keys
    num_dual, num_single = _count_blocks(base_keys)
    if verbose:
        print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")

    out = {}
    # Suffixes copied from a source layer that has no tensor with the
    # candidate's own suffix (the "else" branch below).
    quant_suffixes = {
        ".qweight",
        ".wscales",
        ".wzeros",
        ".smooth",
        ".smooth_orig",
        ".lora_down",
        ".lora_up",
        ".bias",
        ".input_scale",
        ".output_scale",
        ".scale_weight",
    }

    def _map_entry(src_prefix, tgt_template, dvals, block_idx=None):
        """Resolve one DIFFUSERS_MAP entry and store its tensor(s) in ``out``.

        ``block_idx`` is substituted for the "()" placeholder of the target
        template; it is passed explicitly rather than read from the
        enclosing loop variable.
        """
        tgt_base_template, tgt_leaf = tgt_template.rsplit(".", 1)
        tgt_suffix = "." + tgt_leaf
        candidates = dvals if isinstance(dvals[0], (list, tuple)) else [dvals]

        # Invariant across candidates.
        explicit_suffixes = _TARGET_SUFFIXES.get(tgt_base_template, set())
        tgt_base = (
            tgt_base_template.replace("()", str(block_idx))
            if "()" in tgt_base_template
            else tgt_base_template
        )

        for candidate in candidates:
            src_leaf = candidate[0].rsplit(".", 1)[1]
            src_suffix = "." + src_leaf
            # All source keys of a candidate must share the same suffix.
            if any(d.rsplit(".", 1)[1] != src_leaf for d in candidate):
                continue
            src_bases = [d.rsplit(".", 1)[0] for d in candidate]
            suffix_sets = [_collect_suffixes(base_keys, src_prefix + base) for base in src_bases]
            if any(not suffix_set for suffix_set in suffix_sets):
                continue
            common_suffixes = set.intersection(*suffix_sets)
            if not common_suffixes:
                continue

            allow_extra = src_suffix == tgt_suffix
            if src_suffix in common_suffixes:
                suffixes = {src_suffix}
                if allow_extra and tgt_suffix == ".weight":
                    # Also carry over companion tensors the map does not
                    # list explicitly for this target.
                    suffixes |= (common_suffixes - explicit_suffixes)
            else:
                suffixes = common_suffixes & quant_suffixes
            if not suffixes:
                continue

            merge_in_features = tgt_base_template.endswith(".linear2") and len(src_bases) > 1

            # Sorted so the result does not depend on set iteration order
            # when two suffixes normalize to the same output key.
            for suffix in sorted(suffixes):
                values = [src.get(src_prefix + base + suffix) for base in src_bases]
                if len(values) == 1:
                    merged = values[0]
                else:
                    qweights = [
                        src.get(src_prefix + base + ".qweight")
                        if src.has(src_prefix + base + ".qweight")
                        else None
                        for base in src_bases
                    ]
                    norm_suffix = _normalize_suffix(suffix)
                    if merge_in_features:
                        merged = _merge_multi_in(values, norm_suffix, qweights)
                    else:
                        merged = _merge_multi(values, norm_suffix, qweights)
                if merged is None:
                    continue
                out_suffix = tgt_suffix if suffix == src_suffix else _normalize_suffix(suffix)
                out[tgt_base + out_suffix] = merged
            # First usable candidate wins.
            break

    for block_idx in range(num_dual):
        prefix = f"transformer_blocks.{block_idx}."
        for tgt_key, dvals in DIFFUSERS_MAP.items():
            if tgt_key.startswith("double_blocks."):
                _map_entry(prefix, tgt_key, dvals, block_idx)

    for block_idx in range(num_single):
        prefix = f"single_transformer_blocks.{block_idx}."
        for tgt_key, dvals in DIFFUSERS_MAP.items():
            if tgt_key.startswith("single_blocks."):
                _map_entry(prefix, tgt_key, dvals, block_idx)

    # Entries that are not tied to a block.
    for tgt_key, dvals in DIFFUSERS_MAP.items():
        if tgt_key.startswith(("double_blocks.", "single_blocks.")):
            continue
        _map_entry("", tgt_key, dvals)

    # Swap the two halves of the final-layer modulation tensors.
    if "final_layer.adaLN_modulation.1.weight" in out:
        out["final_layer.adaLN_modulation.1.weight"] = _swap_scale_shift_matrix(
            out["final_layer.adaLN_modulation.1.weight"]
        )
    if "final_layer.adaLN_modulation.1.bias" in out:
        out["final_layer.adaLN_modulation.1.bias"] = _swap_scale_shift(
            out["final_layer.adaLN_modulation.1.bias"]
        )

    return out
|
|
|
|
|
class DiffusersSource:
    """
    Uniform interface over:
      1) Folder with index JSON + shards
      2) Folder with exactly one .safetensors (no index)
      3) Single .safetensors file
    Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)

    If the given folder contains a "transformer" sub-folder, that sub-folder
    is used. Opened safetensors handles are kept for the lifetime of the
    object.
    """

    # Prefixes tried, in order, when resolving a requested key.
    POSSIBLE_PREFIXES = ["", "model."]

    def __init__(self, path: Path):
        """Open *path*; raise FileNotFoundError if it is neither a folder
        nor a ``.safetensors`` file."""
        p = Path(path)
        if p.is_dir():
            if (p / "transformer").is_dir():
                p = p / "transformer"
            self._init_from_dir(p)
        elif p.is_file() and p.suffix == ".safetensors":
            self._init_from_single_file(p)
        else:
            raise FileNotFoundError(f"Invalid path: {p}")

    @staticmethod
    def _strip_prefix(k: str) -> str:
        """Return *k* without a leading ``"model."``."""
        return k[6:] if k.startswith("model.") else k

    def _set_keys(self, keys):
        """Record the stored key names (ordered list plus a set for lookups)."""
        self._all_keys = list(keys)
        # Membership is tested for every requested tensor; a set keeps each
        # test O(1) instead of scanning the list.
        self._key_set = set(self._all_keys)

    def _resolve(self, want: str):
        """
        Return the actual stored key matching `want` by trying known prefixes.
        """
        for pref in self.POSSIBLE_PREFIXES:
            k = pref + want
            if k in self._key_set:
                return k
        return None

    def has(self, want: str) -> bool:
        """Tell whether *want* resolves to a stored key."""
        return self._resolve(want) is not None

    def get(self, want: str) -> torch.Tensor:
        """Load and return the tensor for *want*; raise KeyError if missing."""
        real_key = self._resolve(want)
        if real_key is None:
            raise KeyError(f"Missing key: {want}")
        return self._get_by_real_key(real_key).to("cpu")

    @property
    def base_keys(self):
        """All stored keys, each with a leading ``"model."`` removed."""
        return [self._strip_prefix(k) for k in self._all_keys]

    def _init_from_single_file(self, file_path: Path):
        """Serve tensors from one .safetensors file."""
        self._mode = "single"
        self._file = file_path
        self._handle = safe_open(file_path, framework="pt", device="cpu")
        self._set_keys(self._handle.keys())

        def _get_by_real_key(real_key: str):
            return self._handle.get_tensor(real_key)

        self._get_by_real_key = _get_by_real_key

    def _init_from_dir(self, dpath: Path):
        """Serve tensors from a folder: sharded (index JSON) or single file."""
        index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
        if index_json.exists():
            with open(index_json, "r", encoding="utf-8") as f:
                index = json.load(f)
            weight_map = index["weight_map"]
            self._mode = "sharded"
            self._dpath = dpath
            self._weight_map = {k: dpath / v for k, v in weight_map.items()}
            self._set_keys(self._weight_map.keys())
            # Shard handles are opened on first use and then reused.
            self._open_handles = {}

            def _get_by_real_key(real_key: str):
                fpath = self._weight_map[real_key]
                h = self._open_handles.get(fpath)
                if h is None:
                    h = safe_open(fpath, framework="pt", device="cpu")
                    self._open_handles[fpath] = h
                return h.get_tensor(real_key)

            self._get_by_real_key = _get_by_real_key
            return

        # No index: accept the folder only if it holds exactly one file.
        files = sorted(dpath.glob("*.safetensors"))
        if len(files) != 1:
            raise FileNotFoundError(
                f"No index found and {dpath} does not contain exactly one .safetensors file."
            )
        self._init_from_single_file(files[0])
|
|
|
|
|
def main():
    """Command-line entry point: convert, optionally quantize, and save.

    Reads the Diffusers checkpoint named on the command line, converts it to
    the original Flux key layout and writes one .safetensors file.

    * default: floating-point tensors are stored as bfloat16;
    * ``--fp8``: floating-point tensors are stored as float8_e4m3fn using
      stochastic rounding;
    * ``--fp8-scaled``: 2-D ``.weight`` tensors are scaled into float8 and a
      ``.scale_weight`` tensor is added for each; other floating-point
      tensors are clamped and cast to float8. A ``scaled_fp8`` marker tensor
      is added.

    Non-floating-point tensors (for example packed integer weights) are
    written unchanged, because casting them would destroy the packed data.
    If both switches are given, ``--fp8`` takes precedence and no
    ``scaled_fp8`` marker is written.
    """
    import datetime

    args = parse_args()
    src = DiffusersSource(Path(args.diffusers_path))
    orig = _convert_from_source(src, verbose=True)

    if args.fp8 or args.fp8_scaled:
        dtype = torch.float8_e4m3fn
        minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max

        def stochastic_round_to(t):
            # Round to a 1/256 grid, choosing the upper neighbour with
            # probability proportional to the distance from the lower one.
            t = t.float().clamp(minv, maxv)
            lower = torch.floor(t * 256) / 256
            upper = torch.ceil(t * 256) / 256
            prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
            rnd = torch.rand_like(t)
            out = torch.where(rnd < prob, upper, lower)
            return out.to(dtype)

        def scale_to_8bit(weight, target_max=416.0):
            # Scale so the largest magnitude maps to target_max.
            absmax = weight.abs().max()
            scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
            scaled = (weight / scale).clamp(minv, maxv).to(dtype)
            return scaled, scale

        # --fp8 wins when both switches are set, so the file must then not
        # be marked as a scaled checkpoint.
        use_scaled = args.fp8_scaled and not args.fp8

        scales = {}
        for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
            t = orig[k]
            if not t.is_floating_point():
                # Leave integer (packed) tensors untouched.
                continue
            if not use_scaled:
                orig[k] = stochastic_round_to(t)
            elif k.endswith(".weight") and t.dim() == 2:
                qt, s = scale_to_8bit(t)
                orig[k] = qt
                scales[k[:-len(".weight")] + ".scale_weight"] = s
            else:
                orig[k] = t.clamp(minv, maxv).to(dtype)
        if use_scaled:
            orig.update(scales)
            orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
    else:
        for k in list(orig.keys()):
            t = orig[k]
            if t.is_floating_point():
                t = t.to(torch.bfloat16)
            orig[k] = t.cpu()

    out_path = Path(args.output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    meta = OrderedDict()
    meta["format"] = "pt"
    meta["modelspec.date"] = datetime.date.today().strftime("%Y-%m-%d")
    print(f"Saving transformer to: {out_path}")
    safetensors.torch.save_file(orig, str(out_path), metadata=meta)
    print("Done.")
|
|
|
|
|
# Run the converter only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|