|
|
| import warnings
|
| import logging
|
| import torch
|
| import gguf
|
| import re
|
| import os
|
|
|
| from .ops import GGMLTensor
|
| from .dequant import is_quantized, dequantize_tensor
|
|
|
# Diffusion/image-model architectures accepted when loading with is_text_model=False.
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
# Text-encoder architectures accepted when loading with is_text_model=True.
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3"}
# GGUF "general.type" values that mark vision/projector files; such files are
# allowed through even when their architecture is not in TXT_ARCH_LIST.
VIS_TYPE_LIST = {"clip-vision", "mmproj"}
|
|
|
def get_orig_shape(reader, tensor_name):
    """Fetch the original (pre-conversion) shape recorded for a tensor.

    Looks for the custom "comfy.gguf.orig_shape.<name>" metadata key.
    Returns a torch.Size, or None when no override is stored.
    Raises TypeError if the field exists but is not an ARRAY of INT32.
    """
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # The shape must be stored as exactly (ARRAY, INT32).
    expected = (gguf.GGUFValueType.ARRAY, gguf.GGUFValueType.INT32)
    if tuple(field.types) != expected:
        raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")
    dims = [int(field.parts[part_idx][0]) for part_idx in field.data]
    return torch.Size(tuple(dims))
|
|
|
def get_field(reader, field_name, field_type):
    """Read a scalar GGUF metadata field, coerced to the requested Python type.

    Returns None when the field is absent. Raises TypeError on a type
    mismatch (for strings) or an unsupported field_type.
    """
    field = reader.get_field(field_name)
    if field is None:
        return None
    if field_type is str:
        # Strings must be tagged with exactly one STRING value type.
        if list(field.types) != [gguf.GGUFValueType.STRING]:
            raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
        raw = field.parts[field.data[-1]]
        return str(raw, encoding="utf-8")
    if field_type in (int, float, bool):
        # Last data index points at the scalar payload.
        return field_type(field.parts[field.data[-1]])
    raise TypeError(f"Unknown field type {field_type}")
|
|
|
def get_list_field(reader, field_name, field_type):
    """Read an array-valued GGUF metadata field as a tuple of field_type.

    Returns None when the field is absent; raises TypeError for an
    unsupported field_type.
    """
    field = reader.get_field(field_name)
    if field is None:
        return None
    if field_type is str:
        # Each data index addresses one UTF-8 encoded string part.
        return tuple(str(field.parts[idx], encoding="utf-8") for idx in field.data)
    if field_type in (int, float, bool):
        # Numeric parts are single-element arrays; take element 0.
        return tuple(field_type(field.parts[idx][0]) for idx in field.data)
    raise TypeError(f"Unknown field type {field_type}")
|
|
|
def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False):
    """
    Read state dict as fake tensors

    Loads a GGUF file and returns a dict of GGMLTensor wrappers (raw
    quantized data plus the intended tensor shape), without dequantizing.

    path: filesystem path to the .gguf file.
    handle_prefix: if any tensor name starts with this prefix, only those
        tensors are kept and the prefix is stripped from the state-dict keys.
        Pass None to disable prefix handling.
    return_arch: when True, return (state_dict, arch_str) instead of just
        the state dict.
    is_text_model: selects which architecture allow-list is checked
        (TXT_ARCH_LIST vs IMG_ARCH_LIST) and forbids the compat fallback.
    """
    reader = gguf.GGUFReader(path)

    # Detect whether the prefix is present at all; if no tensor uses it,
    # all tensors are kept with their original names.
    has_prefix = False
    if handle_prefix is not None:
        prefix_len = len(handle_prefix)
        tensor_names = set(tensor.name for tensor in reader.tensors)
        has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

    # Collect (state-dict key, raw tensor) pairs, stripping the prefix.
    tensors = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # Determine architecture. Files without "general.architecture" (or with
    # the legacy "pig" tag) are handled in a compatibility mode where the
    # architecture is guessed from the tensor names.
    compat = None
    arch_str = get_field(reader, "general.architecture", str)
    type_str = get_field(reader, "general.type", str)
    if arch_str in [None, "pig"]:
        if is_text_model:
            raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})")
        compat = "sd.cpp" if arch_str is None else arch_str

        # local import to avoid a circular dependency at module load time
        from .tools.convert import detect_arch
        try:
            arch_str = detect_arch(set(val[0] for val in tensors)).arch
        except Exception as e:
            raise ValueError(f"This model is not currently supported - ({e})")
    elif arch_str not in TXT_ARCH_LIST and is_text_model:
        # Unknown text arch is still allowed when the file declares itself
        # as a vision/projector type (e.g. "mmproj").
        if type_str not in VIS_TYPE_LIST:
            raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
    elif arch_str not in IMG_ARCH_LIST and not is_text_model:
        raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")

    if compat:
        logging.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")

    # Main loading loop: wrap each tensor without dequantizing it.
    state_dict = {}
    qtype_dict = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name

        # torch.from_numpy warns on read-only arrays; the data is never
        # written to, so the warning is suppressed.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
            torch_tensor = torch.from_numpy(tensor.data)

        # Prefer the recorded original shape; otherwise GGUF stores dims in
        # reverse order relative to torch.
        shape = get_orig_shape(reader, tensor_name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))

        # sd.cpp SDXL exports pad proj_in/proj_out with trailing singleton
        # dims; trim them back down to 2D.
        if compat == "sd.cpp" and arch_str == "sdxl":
            if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]):
                while len(shape) > 2 and shape[-1] == 1:
                    shape = shape[:-1]

        # Unquantized tensors can be viewed to their final shape directly;
        # quantized ones keep their raw byte layout and carry tensor_shape.
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)

        # Count tensors per quantization type for the summary log line.
        tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

    logging.info("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items()))

    # Flag the largest quantized weight; downstream code uses this marker
    # (is_largest_weight) on the single biggest tensor.
    qsd = {k:v for k,v in state_dict.items() if is_quantized(v)}
    if len(qsd) > 0:
        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
        state_dict[max_key].is_largest_weight = True

    if return_arch:
        return (state_dict, arch_str)
    return state_dict
|
|
|
|
|
# GGUF -> T5 key-fragment replacements, applied via sd_map_replace as plain
# substring replaces in insertion order ("enc." must run before the block
# fragments it prefixes).
T5_SD_MAP = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}
|
|
|
# GGUF -> LLaMA/Qwen key-fragment replacements, applied via sd_map_replace in
# insertion order. Order is significant: "attn_q_norm." / "attn_k_norm." must
# precede the shorter "attn_q" / "attn_k" entries, or the norm keys would be
# mangled by the earlier substring replace.
LLAMA_SD_MAP = {
    "blk.": "model.layers.",
    "attn_norm": "input_layernorm",
    "attn_q_norm.": "self_attn.q_norm.",
    "attn_k_norm.": "self_attn.k_norm.",
    "attn_v_norm.": "self_attn.v_norm.",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "ffn_norm": "post_attention_layernorm",
    "token_embd": "model.embed_tokens",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}
|
|
|
# GGUF mmproj -> vision-tower key-fragment replacements, applied via
# sd_map_replace in insertion order (longer/more specific fragments first).
CLIP_VISION_SD_MAP = {
    "mm.": "visual.merger.mlp.",
    "v.post_ln.": "visual.merger.ln_q.",
    "v.patch_embd": "visual.patch_embed.proj",
    "v.blk.": "visual.blocks.",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "attn_out.": "attn.proj.",
    "ln1.": "norm1.",
    "ln2.": "norm2.",
}
|
|
|
def sd_map_replace(raw_sd, key_map):
    """Rename state-dict keys via substring replacement.

    Each (old, new) pair in key_map is applied to every key, in the
    map's insertion order. Values are passed through untouched.
    """
    renamed = {}
    for key, value in raw_sd.items():
        new_key = key
        for old_frag, new_frag in key_map.items():
            new_key = new_key.replace(old_frag, new_frag)
        renamed[new_key] = value
    return renamed
|
|
|
def llama_permute(raw_sd, n_head, n_head_kv):
    """Reorder Q/K projection rows from GGUF layout to the HF rotary layout.

    GGUF stores the rotary halves interleaved per head; HF expects them
    split. Q weights/biases use n_head, K weights/biases use n_head_kv.
    Tensors are modified in place via their .data attribute.
    """
    def _permute(t, heads):
        # Split each head into (pairs, 2), swap the pair axes, flatten back.
        orig_shape = t.shape
        grouped = t.reshape(heads, orig_shape[0] // heads // 2, 2, *orig_shape[1:])
        return grouped.swapaxes(1, 2).reshape(orig_shape)

    sd = {}
    for key, tensor in raw_sd.items():
        if key.endswith(("q_proj.weight", "q_proj.bias")):
            tensor.data = _permute(tensor.data, n_head)
        if key.endswith(("k_proj.weight", "k_proj.bias")):
            tensor.data = _permute(tensor.data, n_head_kv)
        sd[key] = tensor
    return sd
|
|
|
def strip_quant_suffix(name):
    """Remove a trailing quantization tag (e.g. "-q4_k_m", "_iq2_xs") from a name.

    Returns the name unchanged when no such suffix is found.
    """
    quant_tag = re.compile(r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$", re.IGNORECASE)
    match = quant_tag.search(name)
    return name if match is None else name[:match.start()]
|
|
|
def gguf_mmproj_loader(path):
    """Find and load the companion mmproj (vision projector) GGUF file.

    Scans the directory of the text-encoder file at `path` for a .gguf file
    whose name contains "mmproj" and the encoder's base name (with any quant
    suffix stripped). Returns the vision state dict with keys remapped to
    the visual.* layout, or an empty dict when no candidate is found.
    """
    # fix: log message typo "Attenpting" -> "Attempting"
    logging.info("Attempting to find mmproj file for text encoder...")

    # Base name used for matching: filename without extension or quant tag.
    tenc_fname = os.path.basename(path)
    tenc = os.path.splitext(tenc_fname)[0].lower()
    tenc = strip_quant_suffix(tenc)

    # Collect candidate mmproj files in the same directory.
    target = []
    root = os.path.dirname(path)
    for fname in os.listdir(root):
        name, ext = os.path.splitext(fname)
        if ext.lower() != ".gguf":
            continue
        name = name.lower()  # lowercase once; used for both checks below
        if "mmproj" not in name:
            continue
        if tenc in name:
            target.append(fname)

    if len(target) == 0:
        logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
        return {}
    if len(target) > 1:
        # Best-effort: warn but proceed with the first match.
        logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")

    logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
    target = os.path.join(root, target[0])
    vsd = gguf_sd_loader(target, is_text_model=True)

    # Some files split the patch embedding into two tensors; stack them back
    # along dim 2 (presumably the temporal patch dimension — TODO confirm).
    if "v.patch_embd.weight.1" in vsd:
        w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
        w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
        vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)

    vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)

    # Fuse separate q/k/v projections into single attn.qkv tensors.
    # NOTE(review): the original split attn_q/k/v entries remain in vsd
    # alongside the fused keys; preserved as-is from the prior behavior.
    if "visual.blocks.0.attn_q.weight" in vsd:
        attns = {}
        for k, v in vsd.items():
            if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
                # "blocks.N.attn_q.weight" -> key "blocks.N.attn.qkv.weight"
                k_attn, k_name = k.rsplit(".attn_", 1)
                k_attn += ".attn.qkv." + k_name.split(".")[-1]
                if k_attn not in attns:
                    attns[k_attn] = {}
                attns[k_attn][k_name] = dequantize_tensor(
                    v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
                )

        # Concatenate in q, k, v order along the output dimension.
        for k, v in attns.items():
            suffix = k.split(".")[-1]
            vsd[k] = torch.cat([
                v[f"q.{suffix}"],
                v[f"k.{suffix}"],
                v[f"v.{suffix}"],
            ], dim=0)
        del attns

    return vsd
|
|
|
def gguf_tokenizer_loader(path, temb_shape):
    """Rebuild a sentencepiece tokenizer model from GGUF tokenizer metadata.

    path: GGUF file to read metadata from.
    temb_shape: shape of the token embedding, used to identify the model.

    Returns the serialized sentencepiece ModelProto wrapped in a ByteTensor
    so it can travel inside a state dict. Raises NotImplementedError for
    unrecognized tokenizer/model combinations and ImportError when the
    sentencepiece/protobuf packages are missing.
    """
    logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
    try:
        from sentencepiece import sentencepiece_model_pb2 as model
    except ImportError:
        raise ImportError("Please make sure sentencepiece and protobuf are installed.\npip install sentencepiece protobuf")
    spm = model.ModelProto()

    reader = gguf.GGUFReader(path)

    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
        if temb_shape == (256384, 4096):
            # BUGFIX: was `spm.trainer_spec.model_type == 1`, a no-op
            # comparison that left model_type unset. 1 = UNIGRAM.
            spm.trainer_spec.model_type = 1
        else:
            raise NotImplementedError("Unknown model, can't set tokenizer!")
    else:
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    # Rebuild the vocabulary piece by piece.
    for token, score, toktype in zip(tokens, scores, toktypes):
        piece = spm.SentencePiece()
        piece.piece = token
        piece.score = score
        piece.type = toktype
        spm.pieces.append(piece)

    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.vocab_size = len(tokens)
    spm.trainer_spec.max_sentence_length = 4096
    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)

    logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
    del reader
    return torch.ByteTensor(list(spm.SerializeToString()))
|
|
|
def gguf_clip_loader(path):
    """Load a GGUF text-encoder file and remap its keys per architecture.

    T5 variants get the T5 key map (plus a rebuilt sentencepiece model for
    the known oversized-vocab checkpoint); llama/qwen variants get the
    LLaMA key map, rotary permutation for llama, and the fused vision
    state dict for qwen2vl. Other architectures are returned unmapped.
    """
    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
    temb_key = "token_embd.weight"
    if arch in {"t5", "t5encoder"}:
        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
            # Non-standard vocab: rebuild the tokenizer from file metadata
            # and dequantize the huge embedding up front.
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, T5_SD_MAP)
    elif arch in {"llama", "qwen2vl", "qwen3"}:
        # Large vocab embeddings are dequantized eagerly to avoid OOM later.
        if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
            logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, LLAMA_SD_MAP)
        if arch == "llama":
            sd = llama_permute(sd, 32, 8)
        if arch == "qwen2vl":
            # Merge in the companion vision/projector weights.
            sd.update(gguf_mmproj_loader(path))
    return sd
|
|
|