# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
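"""
Convert a single diffusion model UNET checkpoint (.safetensors / .ckpt / .pt /
.bin / .pth) to an F16 or BF16 GGUF file, detecting the architecture from the
state dict keys.

Example invocation (file names are illustrative only):
    python convert.py --src flux1-dev.safetensors --dst flux1-dev-{ftype}.gguf
"""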
import os
import gguf
import torch
import logging
import argparse
from tqdm import tqdm
from safetensors.torch import load_file, save_file

QUANTIZATION_THRESHOLD = 1024   # tensors with at most this many elements are kept in F32
REARRANGE_THRESHOLD = 512       # minimum element count for the SD1/SDXL shape fix
MAX_TENSOR_NAME_LENGTH = 127    # longest tensor name the GGUF C++ code accepts
MAX_TENSOR_DIMS = 4             # most dimensions the quantization code can handle

class ModelTemplate:
    arch = "invalid"  # string describing architecture
    shape_fix = False  # whether to reshape tensors
    keys_detect = []  # list of lists to match in state dict
    keys_banned = []  # list of keys that should mark model as invalid for conversion
    keys_hiprec = []  # list of keys that must be kept in fp32 (e.g. nn.Parameter tensors)
    keys_ignore = []  # list of strings to ignore keys by when found

    def handle_nd_tensor(self, key, data):
        raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")
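
# Each architecture below is matched by is_model_arch(): a model matches if
# every key in any one tuple from keys_detect is present in the state dict,
# and is rejected if any key from keys_banned is present (this is how the
# reference checkpoint layout is told apart from the diffusers layout).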

class ModelFlux(ModelTemplate):
    arch = "flux"
    keys_detect = [
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight",]

class ModelSD3(ModelTemplate):
    arch = "sd3"
    keys_detect = [
        ("transformer_blocks.0.attn.add_q_proj.weight",),
        ("joint_blocks.0.x_block.attn.qkv.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight",]

class ModelAura(ModelTemplate):
    arch = "aura"
    keys_detect = [
        ("double_layers.3.modX.1.weight",),
        ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
    ]
    keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight",]

class ModelHiDream(ModelTemplate):
    arch = "hidream"
    keys_detect = [
        (
            "caption_projection.0.linear.weight",
            "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight",
        )
    ]
    keys_hiprec = [
        # nn.Parameter, can't load from BF16 ver
        ".ff_i.gate.weight",
        "img_emb.emb_pos",
    ]

class CosmosPredict2(ModelTemplate):
    arch = "cosmos"
    keys_detect = [
        (
            "blocks.0.mlp.layer1.weight",
            "blocks.0.adaln_modulation_cross_attn.1.weight",
        )
    ]
    keys_hiprec = ["pos_embedder"]
    keys_ignore = ["_extra_state", "accum_"]

class ModelHyVid(ModelTemplate):
    arch = "hyvid"
    keys_detect = [
        (
            "double_blocks.0.img_attn_proj.weight",
            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
        )
    ]

    def handle_nd_tensor(self, key, data):
        # hacky but don't have any better ideas
        path = f"./fix_5d_tensors_{self.arch}.safetensors"  # TODO: somehow get a path here??
        if os.path.isfile(path):
            raise RuntimeError(f"5D tensor fix file already exists! {path}")
        fsd = {key: torch.from_numpy(data)}
        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
        save_file(fsd, path)

class ModelWan(ModelHyVid):
    arch = "wan"
    keys_detect = [
        (
            "blocks.0.self_attn.norm_q.weight",
            "text_embedding.2.weight",
            "head.modulation",
        )
    ]
    keys_hiprec = [
        ".modulation"  # nn.Parameter, can't load from BF16 ver
    ]

class ModelLTXV(ModelTemplate):
    arch = "ltxv"
    keys_detect = [
        (
            "adaln_single.emb.timestep_embedder.linear_2.weight",
            "transformer_blocks.27.scale_shift_table",
            "caption_projection.linear_2.weight",
        )
    ]
    keys_hiprec = [
        "scale_shift_table"  # nn.Parameter, can't load from BF16 base quant
    ]

class ModelSDXL(ModelTemplate):
    arch = "sdxl"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
        ),  # Non-diffusers
        ("label_emb.0.0.weight",),
    ]

class ModelSD1(ModelTemplate):
    arch = "sd1"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight",
        ),  # Non-diffusers
    ]

class ModelLumina2(ModelTemplate):
    arch = "lumina2"
    keys_detect = [
        ("cap_embedder.1.weight", "context_refiner.0.attention.qkv.weight"),
    ]

# checked in order; the first architecture whose detection keys all match wins
arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2,
             ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1, ModelLumina2]

def is_model_arch(model, state_dict):
    # check if model is correct
    matched = False
    invalid = False
    for match_list in model.keys_detect:
        if all(key in state_dict for key in match_list):
            matched = True
            invalid = any(key in state_dict for key in model.keys_banned)
            break
    assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
    return matched

def detect_arch(state_dict):
    model_arch = None
    for arch in arch_list:
        if is_model_arch(arch, state_dict):
            model_arch = arch()
            break
    assert model_arch is not None, "Unknown model architecture!"
    return model_arch
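
# Example (file name is illustrative only):
#   state_dict = load_state_dict("flux1-dev.safetensors")
#   detect_arch(state_dict).arch  # -> "flux"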

def parse_args():
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
    parser.add_argument("--src", required=True, help="Source model ckpt file.")
    parser.add_argument("--dst", help="Output unet gguf file.")
    args = parser.parse_args()

    if not os.path.isfile(args.src):
        parser.error(f"Invalid source file: '{args.src}'")

    return args

def strip_prefix(state_dict):
    # prefix for mixed state dict
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any(x.startswith(pfx) for x in state_dict.keys()):
            prefix = pfx
            break

    # prefix for uniform state dict
    if prefix is None:
        for pfx in ["net."]:
            if all(x.startswith(pfx) for x in state_dict.keys()):
                prefix = pfx
                break

    # strip prefix if found
    if prefix is not None:
        logging.info(f"State dict prefix found: '{prefix}'")
        sd = {}
        for k, v in state_dict.items():
            # match on the key start only, so keys that merely contain the
            # prefix as a substring are not mangled
            if not k.startswith(prefix):
                continue
            sd[k[len(prefix):]] = v
    else:
        logging.debug("State dict has no prefix")
        sd = state_dict

    return sd

def load_state_dict(path):
    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        for subkey in ["model", "module"]:
            if subkey in state_dict:
                state_dict = state_dict[subkey]
                break
        if len(state_dict) < 20:
            raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
    else:
        state_dict = load_file(path)
    return strip_prefix(state_dict)
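
# Per-tensor precision rules applied below: BF16 sources are stored as BF16
# and everything else as F16, except that F32/BF16 tensors which are 1D, have
# <= QUANTIZATION_THRESHOLD elements, or match keys_hiprec are kept in F32.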
def handle_tensors(writer, state_dict, model_arch):
    name_lengths = tuple(sorted(
        ((key, len(key)) for key in state_dict.keys()),
        key=lambda item: item[1],
        reverse=True,
    ))
    if not name_lengths:
        return
    max_name_len = name_lengths[0][1]
    if max_name_len > MAX_TENSOR_NAME_LENGTH:
        bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
        raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")

    for key, data in tqdm(state_dict.items()):
        old_dtype = data.dtype

        if any(x in key for x in model_arch.keys_ignore):
            tqdm.write(f"Filtering ignored key: '{key}'")
            continue

        if data.dtype == torch.bfloat16:
            data = data.to(torch.float32).numpy()
        # this is so we don't break torch 2.0.X
        elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
            data = data.to(torch.float16).numpy()
        else:
            data = data.numpy()

        n_dims = len(data.shape)
        data_shape = data.shape
        if old_dtype == torch.bfloat16:
            data_qtype = gguf.GGMLQuantizationType.BF16
        # elif old_dtype == torch.float32:
        #     data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16

        # The max no. of dimensions that can be handled by the quantization code is 4
        if len(data.shape) > MAX_TENSOR_DIMS:
            model_arch.handle_nd_tensor(key, data)
            continue  # needs to be added back later

        # get number of parameters (AKA elements) in this tensor
        n_params = 1
        for dim_size in data_shape:
            n_params *= dim_size

        if old_dtype in (torch.float32, torch.bfloat16):
            if n_dims == 1:
                # one-dimensional tensors should be kept in F32
                # also speeds up inference due to not dequantizing
                data_qtype = gguf.GGMLQuantizationType.F32
            elif n_params <= QUANTIZATION_THRESHOLD:
                # very small tensors
                data_qtype = gguf.GGMLQuantizationType.F32
            elif any(x in key for x in model_arch.keys_hiprec):
                # tensors that require max precision
                data_qtype = gguf.GGMLQuantizationType.F32
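
        # Worked example for the shape fix below (hypothetical SD1 conv weight):
        # a tensor of shape (320, 4, 3, 3) has 11520 elements; its last dim (3)
        # is not divisible by 256 but its total is (11520 / 256 = 45), so it is
        # flattened to (45, 256) and the original shape is recorded under
        # comfy.gguf.orig_shape so loaders can restore it.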
        if (model_arch.shape_fix  # NEVER reshape for models such as flux
            and n_dims > 1  # Skip one-dimensional tensors
            and n_params >= REARRANGE_THRESHOLD  # Only rearrange tensors meeting the size requirement
            and (n_params / 256).is_integer()  # Rearranging only makes sense if total elements is divisible by 256
            and not (data.shape[-1] / 256).is_integer()  # Only need to rearrange if the last dimension is not divisible by 256
        ):
            orig_shape = data.shape
            data = data.reshape(n_params // 256, 256)
            writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))

        try:
            data = gguf.quants.quantize(data, data_qtype)
        except (AttributeError, gguf.QuantError) as e:
            tqdm.write(f"falling back to F16: {e}")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        new_name = key  # do we need to rename?

        shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
        tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
        writer.add_tensor(new_name, data, raw_dtype=data_qtype)

def convert_file(path, dst_path=None, interact=True, overwrite=False):
    # load & run model detection logic
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    logging.info(f"* Architecture detected from input: {model_arch.arch}")

    # detect & set dtype for output file
    dtypes = [x.dtype for x in state_dict.values()]
    dtypes = {x: dtypes.count(x) for x in set(dtypes)}
    main_dtype = max(dtypes, key=dtypes.get)

    if main_dtype == torch.bfloat16:
        ftype_name = "BF16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
    # elif main_dtype == torch.float32:
    #     ftype_name = "F32"
    #     ftype_gguf = None
    else:
        ftype_name = "F16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16

    if dst_path is None:
        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
    elif "{ftype}" in dst_path:  # lcpp logic
        dst_path = dst_path.replace("{ftype}", ftype_name)

    if os.path.isfile(dst_path) and not overwrite:
        if interact:
            input("Output exists. Press enter to continue or ctrl+c to abort!")
        else:
            raise OSError("Output exists and overwriting is disabled!")

    # handle actual file
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    if ftype_gguf is not None:
        writer.add_file_type(ftype_gguf)

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=dst_path)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()

    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
    if os.path.isfile(fix):
        logging.warning(f"\n### Warning! Fix file found at '{fix}'")
        logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")

    return dst_path, model_arch
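
# convert_file() can also be used programmatically, e.g. (illustrative file name):
#   convert_file("flux1-dev.safetensors", interact=False, overwrite=True)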

if __name__ == "__main__":
    args = parse_args()
    convert_file(args.src, args.dst)