File size: 2,409 Bytes
e69a9f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import sys
import argparse
import torch
from lycoris.utils import merge_loha, merge_locon
from lycoris.kohya_model_utils import (
    load_models_from_stable_diffusion_checkpoint,
    save_stable_diffusion_checkpoint,
    load_file
)
import gradio as gr


# Canonical dtype names accepted by merge_models (after alias expansion).
_DTYPES = {
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
    'bfloat': torch.bfloat16,
    'bfloat16': torch.bfloat16,
}


def _resolve_dtype(name):
    """Map a user-supplied dtype name to a ``torch.dtype``.

    Accepts the names in ``_DTYPES`` plus the short aliases ``fp``/``fpNN``
    (e.g. ``fp16``) and ``bf``/``bfNN`` (e.g. ``bf16``), case-insensitively.

    Raises:
        ValueError: if *name* is not a recognised dtype.
    """
    key = (name or '').strip().lower()
    # Expand short aliases by prefix only. A blanket str.replace('bf', 'bfloat')
    # would also rewrite the already-long 'bfloat16' into 'bfloatloat16'.
    if key.startswith('fp'):
        key = 'float' + key[2:]
    elif key.startswith('bf') and not key.startswith('bfloat'):
        key = 'bfloat' + key[2:]
    resolved = _DTYPES.get(key)
    if resolved is None:
        # Report what the caller passed, not the failed lookup result.
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return resolved


def _load_lycoris(path):
    """Load a LyCORIS state dict from a ``.safetensors`` file or a torch file."""
    if path.endswith('.safetensors'):
        return load_file(path)
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only open LyCORIS files from trusted sources.
    # map_location='cpu' keeps loading from failing when the file holds GPU
    # tensors but no GPU is present; the target device is passed separately
    # to the merge helpers.
    return torch.load(path, map_location='cpu')


def _detect_algo(lyco):
    """Return 'loha' or 'lora' based on the first recognisable key in *lyco*.

    Raises:
        NotImplementedError: if no key contains 'hada' or 'lora_up'.
    """
    for key in lyco:
        if 'hada' in key:
            return 'loha'
        if 'lora_up' in key:
            return 'lora'
    raise NotImplementedError('Cannot find the algo for this lycoris model file.')


def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight):
    """Merge a LyCORIS (LoHa or LoRA/LoCon) file into a Stable Diffusion checkpoint.

    Args:
        base_model: path to the base Stable Diffusion checkpoint.
        lycoris_model: path to the LyCORIS file; ``.safetensors`` files are read
            with ``load_file``, anything else with ``torch.load``.
        output_name: path the merged checkpoint is written to.
        is_v2: whether the base model is an SD v2 checkpoint.
        device: device forwarded to ``merge_loha`` / ``merge_locon``.
        dtype: dtype name used when saving, e.g. ``'float16'``, ``'fp16'``,
            ``'bfloat16'`` or ``'bf16'``.
        weight: multiplier forwarded to ``merge_loha`` / ``merge_locon``.

    Returns:
        *output_name*, the path the merged checkpoint was saved to.

    Raises:
        ValueError: if *dtype* is not recognised.
        NotImplementedError: if the LyCORIS algorithm cannot be detected.
    """
    # Validate the cheap inputs before loading the (large) base checkpoint.
    save_dtype = _resolve_dtype(dtype)
    lyco = _load_lycoris(lycoris_model)
    algo = _detect_algo(lyco)

    base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
    if algo == 'loha':
        merge_loha(base, lyco, weight, device)
    else:  # _detect_algo only ever returns 'loha' or 'lora'
        merge_locon(base, lyco, weight, device)

    # NOTE(review): presumably base is (text_encoder, vae, unet), given the
    # argument order below — verify against kohya_model_utils.
    save_stable_diffusion_checkpoint(
        is_v2, output_name,
        base[0], base[2],
        None, 0, 0, save_dtype,
        base[1],
    )
    return output_name


def main():
    """Build the Gradio UI for ``merge_models`` and start the web server."""
    # Top-level components replace the gr.inputs / gr.outputs namespaces,
    # which Gradio 3 deprecated and Gradio 4 removed; ``value=`` replaces the
    # old ``default=`` keyword. One input per merge_models parameter, in order.
    iface = gr.Interface(
        fn=merge_models,
        inputs=[
            gr.Textbox(label="Base Model Path"),
            gr.Textbox(label="Lycoris Model Path"),
            gr.Textbox(label="Output Model Path", value='./out.pt'),
            gr.Checkbox(label="Is base model SD V2?", value=False),
            gr.Textbox(label="Device", value='cpu'),
            gr.Dropdown(
                choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'],
                label="Dtype",
                value='float',
            ),
            gr.Number(label="Weight", value=1.0),
        ],
        outputs=gr.Textbox(label="Merged Model Path"),
        title="Model Merger",
        description="Merge Lycoris and Stable Diffusion models",
    )

    iface.launch()


# Launch the web UI only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()