| | import os |
| | import tqdm |
| | import torch |
| | import safetensors.torch |
| | from torch import Tensor |
| | from modules import shared |
| | from modules import sd_models, sd_vae |
| |
|
| | |
# Source dtypes that get downcast when the target precision is fp16.
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
# Source dtypes that get converted when the target precision is bf16.
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
| |
|
| |
|
class MockModelInfo:
    """Minimal stand-in for webui's CheckpointInfo.

    Carries only the path-derived attributes this module needs:
    full path, basename, and the basename up to the first dot.
    """

    def __init__(self, model_path: str) -> None:
        base = os.path.basename(model_path)
        self.filepath = model_path
        self.filename: str = base
        # Everything before the first "." — e.g. "model.v1.ckpt" -> "model".
        self.model_name: str = base.split(".")[0]
| |
|
| |
|
def conv_fp16(t: Tensor):
    """Downcast *t* to float16 if its dtype is in dtypes_to_fp16, else pass through."""
    if t.dtype in dtypes_to_fp16:
        return t.half()
    return t
| |
|
| |
|
def conv_bf16(t: Tensor):
    """Convert *t* to bfloat16 if its dtype is in dtypes_to_bf16, else pass through."""
    if t.dtype in dtypes_to_bf16:
        return t.bfloat16()
    return t
| |
|
| |
|
def conv_full(t):
    """Identity conversion: keep the tensor at its original precision."""
    return t
| |
|
| |
|
# Maps the UI precision choice to its per-tensor conversion function.
# "full" and "fp32" both leave tensors untouched.
_g_precision_func = {
    "full": conv_full,
    "fp32": conv_full,
    "fp16": conv_fp16,
    "bf16": conv_bf16,
}
| |
|
| |
|
def check_weight_type(k: str) -> str:
    """Classify a state-dict key by the model component it belongs to.

    Returns one of "unet", "vae", "clip", or "other".
    """
    component_prefixes = (
        ("model.diffusion_model", "unet"),
        ("first_stage_model", "vae"),
        ("cond_stage_model", "clip"),
    )
    for prefix, component in component_prefixes:
        if k.startswith(prefix):
            return component
    return "other"
| |
|
| |
|
def load_model(path):
    """Load a checkpoint from *path* onto CPU and return its flat state dict.

    Supports both safetensors files and torch-pickled checkpoints; when the
    loaded object wraps its weights under a "state_dict" key, that inner
    dict is returned instead of the wrapper.
    """
    if path.endswith(".safetensors"):
        loaded = safetensors.torch.load_file(path, device="cpu")
    else:
        loaded = torch.load(path, map_location="cpu")
    return loaded.get("state_dict", loaded)
| |
|
| |
|
def fix_model(model, fix_clip=False, force_position_id=False):
    """Repair known state-dict defects in place.

    - Renames legacy NovelAI CLIP keys (missing the ".text_model." segment)
      to the standard transformers layout.
    - Optionally forces the CLIP position_ids tensor to int64.
    - Optionally rebuilds position_ids when they deviate from 0..76
      (a known corruption introduced by some merge tools).

    Args:
        model: checkpoint state dict; mutated in place.
        fix_clip: verify/rebuild the CLIP position_ids tensor.
        force_position_id: cast existing position_ids to int64.

    Returns:
        The same (mutated) state dict, for call-chaining convenience.
    """
    nai_keys = {
        'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
        'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
        'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
    }
    position_id_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"

    # Snapshot the keys: we rename entries while walking them.
    for k in list(model.keys()):
        for r in nai_keys:
            # isinstance instead of type(k) == str: also accepts str subclasses.
            if isinstance(k, str) and k.startswith(r):
                new_key = k.replace(r, nai_keys[r])
                model[new_key] = model.pop(k)
                print(f"[Converter] Fixed novelai error key {k}")
                break  # a key can match at most one prefix

    if force_position_id and position_id_key in model:
        model[position_id_key] = model[position_id_key].to(torch.int64)

    if fix_clip:
        if position_id_key in model:
            # CLIP expects exactly [[0, 1, ..., 76]] as int64.
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = model[position_id_key].to(torch.int64)

            # Assumes position_ids has shape (1, 77) — TODO confirm upstream.
            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]
            if len(broken) != 0:
                model[position_id_key] = correct
                print(f"[Converter] Fixed broken clip\n{broken}")
            else:
                print("[Converter] Clip in this model is fine, skip fixing...")
        else:
            print("[Converter] Missing position id in model, try fixing...")
            model[position_id_key] = torch.Tensor([list(range(77))]).to(torch.int64)

    return model
| |
|
| |
|
def convert_warp(
    model_name, model_path, directory,
    *args
):
    """Dispatch a conversion request to do_convert.

    Exactly one of model_name / model_path / directory must be non-empty.
    *args is forwarded positionally to do_convert: (checkpoint_formats,
    precision, conv_type, custom_name, bake_in_vae, unet_conv,
    text_encoder_conv, vae_conv, others_conv, fix_clip, force_position_id,
    delete_known_junk_data).

    Returns a status/error string, or None for some failure paths
    (e.g. multiple inputs set, or a model_path that does not exist).
    """
    # Exactly one input source must be provided ("" is falsy).
    if sum(map(bool, [model_name, model_path, directory])) != 1:
        print("[Converter] Check your inputs. Multiple input was set or missing input")
        return

    if directory != "":
        if not os.path.exists(directory) or not os.path.isdir(directory):
            return "Error: path not exists or not dir"

        files = [f for f in os.listdir(directory) if f.endswith(".ckpt") or f.endswith(".safetensors")]

        if len(files) == 0:
            return "Error: cant found model in directory"

        _args = list(args)
        # args[3] is custom_name; blank it in batch mode so each file derives
        # its own output name instead of all writes colliding on one name.
        _args[3] = ""

        # Batch mode: convert every checkpoint found; per-file results are
        # not aggregated, so nothing is returned here.
        for m in files:
            do_convert(MockModelInfo(os.path.join(directory, m)), *_args)

    elif model_path != "":
        # NOTE(review): silently returns None when model_path does not exist
        # — no error text reaches the caller. Confirm whether intentional.
        if os.path.exists(model_path):
            return do_convert(MockModelInfo(model_path), *args)

    elif model_name != "":
        # Resolve a checkpoint registered with the webui by its list title.
        model_info = sd_models.checkpoints_list[model_name]
        return do_convert(MockModelInfo(model_info.filename), *args)

    else:
        return "Error: must choose a model"
| |
|
| |
|
def do_convert(model_info: MockModelInfo,
               checkpoint_formats,
               precision, conv_type, custom_name,
               bake_in_vae,
               unet_conv, text_encoder_conv, vae_conv, others_conv,
               fix_clip, force_position_id, delete_known_junk_data):
    """Convert one checkpoint and save it next to the original.

    Loads the checkpoint, applies per-component precision policy
    ("convert" / "copy" / "delete" for unet, clip, vae, other), handles
    EMA selection, optionally strips known junk keys and bakes in a VAE,
    then writes the result in each requested format.

    Args:
        model_info: source checkpoint (filepath/filename/model_name).
        checkpoint_formats: list of output formats ("ckpt", "safetensors").
        precision: key into _g_precision_func ("full"/"fp32"/"fp16"/"bf16").
        conv_type: "ema-only", "no-ema", or anything else for all weights.
        custom_name: overrides the derived output name when non-empty.
        bake_in_vae: title of a VAE from sd_vae.vae_dict to merge in.
        unet_conv, text_encoder_conv, vae_conv, others_conv: per-component
            policy strings ("convert"/"copy"/"delete").
        fix_clip, force_position_id: forwarded to fix_model.
        delete_known_junk_data: drop embedder/LoRA/ControlNet leftovers.

    Returns:
        A newline-joined status string, or an error string.
    """
    if len(checkpoint_formats) == 0:
        return "Error: at least choose one model save format"

    # Per-component conversion policy, keyed by check_weight_type's result.
    extra_opt = {
        "unet": unet_conv,
        "clip": text_encoder_conv,
        "vae": vae_conv,
        "other": others_conv
    }
    shared.state.begin()
    shared.state.job = 'model-convert'
    shared.state.textinfo = f"Loading {model_info.filename}..."
    print(f"[Converter] Loading {model_info.filename}...")

    ok = {}  # the output state dict, filled by _hf below
    state_dict = load_model(model_info.filepath)
    fix_model(state_dict, fix_clip=fix_clip, force_position_id=force_position_id)

    conv_func = _g_precision_func[precision]

    def _hf(wk: str, t: Tensor):
        # Apply the per-component policy to one weight.
        # Non-tensor entries (e.g. step counters) are dropped entirely.
        if not isinstance(t, Tensor):
            return
        weight_type = check_weight_type(wk)
        conv_t = extra_opt[weight_type]
        if conv_t == "convert":
            ok[wk] = conv_func(t)
        elif conv_t == "copy":
            ok[wk] = t
        elif conv_t == "delete":
            return

    print("[Converter] Converting model...")

    if conv_type == "ema-only":
        for k in tqdm.tqdm(state_dict):
            # EMA weights live under "model_ema." with the dots of the rest
            # of the key stripped; k[6:] assumes k starts with "model."
            # (TODO confirm for non-"model." keys). Prefer the EMA weight
            # when one exists for this key.
            ema_k = "___"
            if isinstance(k, str):  # str ops below cannot raise on str keys
                ema_k = "model_ema." + k[6:].replace(".", "")
            if ema_k in state_dict:
                _hf(k, state_dict[ema_k])
            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
                _hf(k, state_dict[k])
    elif conv_type == "no-ema":
        for k, v in tqdm.tqdm(state_dict.items()):
            if "model_ema." not in k:
                _hf(k, v)
    else:
        for k, v in tqdm.tqdm(state_dict.items()):
            _hf(k, v)

    if delete_known_junk_data:
        # Leftovers from textual-inversion embedders, merged-in LoRA
        # text-encoder weights, and ControlNet models.
        known_junk_data_prefix = [
            "embedding_manager.embedder.",
            "lora_te_text_model",
            "control_model."
        ]
        # Collect first, then delete: never mutate a dict while iterating it.
        need_delete = [key for key in ok
                       if key.startswith(tuple(known_junk_data_prefix))]
        for k in need_delete:
            del ok[k]

    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"[Converter] Baking in VAE from {bake_in_vae_filename}")
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        # NOTE(review): keys from load_vae_dict may lack the
        # "first_stage_model." prefix, in which case _hf classifies them as
        # "other" rather than "vae" — confirm against sd_vae.
        for k, v in vae_dict.items():
            _hf(k, v)

        del vae_dict

    output = ""
    ckpt_dir = os.path.dirname(model_info.filepath)
    # Derived name encodes precision / conversion type / clip fix, unless
    # the user supplied an explicit name.
    save_name = f"{model_info.model_name}-{precision}"
    if conv_type != "disabled":
        save_name += f"-{conv_type}"
    if fix_clip:
        save_name += "-clip-fix"
    if custom_name != "":
        save_name = custom_name

    for fmt in checkpoint_formats:
        ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
        _save_name = save_name + ext

        save_path = os.path.join(ckpt_dir, _save_name)
        print(f"[Converter] Saving to {save_path}...")

        if fmt == "safetensors":
            safetensors.torch.save_file(ok, save_path)
        else:
            # .ckpt convention: weights wrapped under "state_dict".
            torch.save({"state_dict": ok}, save_path)
        output += f"Checkpoint saved to {save_path}\n"

    shared.state.end()
    return output[:-1]  # drop the trailing newline
| |
|