| | |
| | |
| | |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import time |
| | import torch |
| | from safetensors.torch import load_file, save_file |
| | from safetensors import safe_open |
| | from tqdm import tqdm |
| | from .library import flux_utils, sai_model_spec |
| | from .library.utils import MemoryEfficientSafeOpen |
| | from .library.utils import setup_logging |
| | from .networks import lora_flux |
| |
|
| | setup_logging() |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | from comfy.utils import ProgressBar |
| | |
| | |
| |
|
| |
|
def save_to_file(file_name, state_dict, metadata, dtype):
    """Optionally cast the tensors in ``state_dict`` and write them with ``save_file``.

    Args:
        file_name: Destination path handed to ``save_file``.
        state_dict: Mapping of names to values. When ``dtype`` is given, tensor
            entries are converted and written back into this same dict (the
            caller's mapping is updated); non-tensor entries are left untouched.
        metadata: Metadata mapping forwarded to ``save_file`` unchanged.
        dtype: Target ``torch.dtype``, or ``None`` to keep every entry as is.
    """
    if dtype is not None:
        # Snapshot the items so the dict can be reassigned while looping.
        for key, value in list(state_dict.items()):
            # isinstance instead of an exact type comparison, so tensor
            # subclasses (e.g. torch.nn.Parameter) are converted as well.
            if isinstance(value, torch.Tensor):
                state_dict[key] = value.to(dtype)

    save_file(state_dict, file_name, metadata=metadata)
| |
|
| |
|
def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    device=None,
    store_device='cpu',
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    mem_eff_safe_open=False,
):
    """Extract a LoRA from the weight difference of two checkpoints using a truncated SVD.

    For every selected weight key, ``tuned - original`` is computed in float32,
    decomposed with SVD, truncated to ``rank = min(dim, in_dim, out_dim)`` and
    stored as an up matrix (``U @ diag(S)``) and a down matrix (``Vh``). Both
    factors are clamped to ``[-q, q]`` where ``q`` is the ``clamp_quantile``
    quantile of their concatenated values. The result is written to ``save_to``.

    Args:
        model_org: Path of the original model file (opened as safetensors).
        model_tuned: Path of the tuned model file (opened as safetensors).
        save_to: Output path for the extracted LoRA.
        dim: Requested LoRA rank; the effective rank per weight is capped by
            the first two dimensions of that weight.
        device: If truthy, the difference matrix is moved to this device
            before the decomposition.
        store_device: Device on which the resulting factors are kept until saving.
        save_precision: One of ``"float"``, ``"fp16"``, ``"bf16"``; any other
            value (including ``None``) leaves the dtype unchanged.
        clamp_quantile: Quantile used to derive the symmetric clamp bounds.
        min_diff: Accepted for interface compatibility; not used by this function.
        no_metadata: When true, only the minimal ``ss_*`` metadata is saved and
            the ``sai_model_spec`` metadata is skipped.
        mem_eff_safe_open: Open the files with ``MemoryEfficientSafeOpen``
            instead of ``safe_open``.

    Returns:
        The ``save_to`` path that was written.
    """
    dtype_by_name = {
        "float": torch.float,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    calc_dtype = torch.float
    # Unknown names and None map to None, meaning "keep the dtype as is".
    save_dtype = dtype_by_name.get(save_precision)

    if mem_eff_safe_open:
        logger.info("Using memory efficient safe_open")

    def open_fn(fn):
        """Open a checkpoint with the reader selected by ``mem_eff_safe_open``."""
        if mem_eff_safe_open:
            return MemoryEfficientSafeOpen(fn)
        return safe_open(fn, framework="pt")

    lora_weights = {}
    with open_fn(model_org) as fo:
        # Keep only single/double block weights; skip biases and norm layers.
        keys = [
            key
            for key in fo.keys()
            if ("single_block" in key or "double_block" in key)
            and ".bias" not in key
            and "norm" not in key
        ]
        comfy_pbar = ProgressBar(len(keys))
        with open_fn(model_tuned) as ft:
            for key in tqdm(keys):
                # Difference of the two weights, computed in float32.
                value_o = fo.get_tensor(key)
                value_t = ft.get_tensor(key)
                mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
                del value_o, value_t

                if device:
                    mat = mat.to(device)
                out_dim, in_dim = mat.size()[0:2]
                rank = min(dim, in_dim, out_dim)

                mat = mat.squeeze()

                # Reduced SVD: only the first `rank` singular vectors are kept
                # below, so the full square U / Vh that the default
                # full_matrices=True would build are never needed.
                U, S, Vh = torch.linalg.svd(mat, full_matrices=False)

                # Fold the singular values into the up matrix.
                U = U[:, :rank]
                S = S[:rank]
                U = U @ torch.diag(S)

                Vh = Vh[:rank, :]

                # Symmetric clamp at the requested quantile of all factor values.
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp_quantile)
                low_val = -hi_val

                U = U.clamp(low_val, hi_val)
                Vh = Vh.clamp(low_val, hi_val)

                U = U.to(store_device, dtype=save_dtype).contiguous()
                Vh = Vh.to(store_device, dtype=save_dtype).contiguous()

                print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}")
                comfy_pbar.update(1)
                lora_weights[key] = (U, Vh)
                del mat, U, S, Vh

    # Build the LoRA state dict: up/down weights plus alpha (= stored rank).
    lora_sd = {}
    for key, (up_weight, down_weight) in lora_weights.items():
        lora_name = key.replace(".weight", "").replace(".", "_")
        lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])

    # Minimal metadata that is always written.
    net_kwargs = {}
    metadata = {
        "ss_v2": str(False),
        "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
        "ss_network_module": "networks.lora_flux",
        "ss_network_dim": str(dim),
        "ss_network_alpha": str(float(dim)),
        "ss_network_args": json.dumps(net_kwargs),
    }

    if not no_metadata:
        # The output file name without its extension is used as the title.
        title = os.path.splitext(os.path.basename(save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
        metadata.update(sai_metadata)

    save_to_file(save_to, lora_sd, metadata, save_dtype)

    logger.info(f"LoRA weights saved to {save_to}")
    return save_to
| |
|
| |
|
def setup_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the LoRA extraction script.

    The options are declared in a table and registered in order, so the
    generated help lists them in the same sequence as they appear here.

    Returns:
        An ``argparse.ArgumentParser`` with ``--model_org``, ``--model_tuned``
        and ``--save_to`` required, plus the optional tuning flags.
    """
    option_table = [
        (
            "--save_precision",
            dict(
                type=str,
                default=None,
                choices=[None, "float", "fp16", "bf16"],
                help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
            ),
        ),
        (
            "--model_org",
            dict(
                type=str,
                default=None,
                required=True,
                help="Original model: safetensors file / 元モデル、safetensors",
            ),
        ),
        (
            "--model_tuned",
            dict(
                type=str,
                default=None,
                required=True,
                help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
            ),
        ),
        (
            "--mem_eff_safe_open",
            dict(
                action="store_true",
                help=(
                    "use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
                    " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。"
                ),
            ),
        ),
        (
            "--save_to",
            dict(
                type=str,
                default=None,
                required=True,
                help="destination file name: safetensors file / 保存先のファイル名、safetensors",
            ),
        ),
        (
            "--dim",
            dict(
                type=int,
                default=4,
                help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                default=None,
                help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う",
            ),
        ),
        (
            "--clamp_quantile",
            dict(
                type=float,
                default=0.99,
                help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
            ),
        ),
        (
            "--no_metadata",
            dict(
                action="store_true",
                help=(
                    "do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
                    "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)"
                ),
            ),
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
| |
|
| |
|
if __name__ == "__main__":
    # Parse the command line and forward every option to svd() as a keyword
    # argument; the option names match svd()'s parameter names.
    cli_args = setup_parser().parse_args()
    svd(**vars(cli_args))