File size: 4,563 Bytes
0163a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from modules import script_callbacks
from modules import sd_models, sd_vae
from modules.ui import create_refresh_button
from scripts import convert


def gr_show(visible=True):
    """Build a Gradio "update" payload that sets a component's visibility.

    Args:
        visible: Whether the target component should be shown.

    Returns:
        A dict carrying the ``visible`` flag and the ``"update"`` type marker.
    """
    return dict(visible=visible, __type__="update")


def add_tab():
    """Build the "Model Converter" tab.

    The tab lets the user pick the source checkpoint in one of three ways:
    a registered model (dropdown), an explicit file path, or a batch
    directory. It also exposes the conversion options: precision, pruning
    method, output formats, VAE to bake in, per-part handling and CLIP fixes.
    The Convert button calls ``convert.convert_warp``, and whatever that
    returns is shown in the result textbox.

    Returns:
        ``[(ui, "Model Converter", "model_converter")]``: the Blocks object
        plus two strings, in the form this module registers through
        ``script_callbacks.on_ui_tabs`` (presumably tab title and tab id).
    """
    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Row(equal_height=True):
            with gr.Column(variant='panel'):
                gr.HTML(value="<p>Converted checkpoints will be saved in your <b>checkpoint</b> directory.</p>")
                with gr.Tabs():
                    with gr.TabItem(label='Single process'):
                        with gr.Row():
                            model_name = gr.Dropdown(sd_models.checkpoint_tiles(),
                                                     elem_id="model_converter_model_name",
                                                     label="Model")
                            create_refresh_button(model_name, sd_models.list_models,
                                                  lambda: {"choices": sd_models.checkpoint_tiles()},
                                                  "refresh_checkpoint_Z")
                        custom_name = gr.Textbox(label="Custom Name (Optional)")

                    with gr.TabItem(label='Input file path'):
                        with gr.Row():
                            model_path = gr.Textbox(label="model path")

                    with gr.TabItem(label='Batch from directory'):
                        with gr.Row():
                            input_directory = gr.Textbox(label="Input Directory")

                with gr.Row():
                    precision = gr.Radio(choices=["fp32", "fp16", "bf16"], value="fp16", label="Precision")
                    m_type = gr.Radio(choices=["disabled", "no-ema", "ema-only"], value="disabled", label="Pruning Methods")

                with gr.Row():
                    checkpoint_formats = gr.CheckboxGroup(choices=["ckpt", "safetensors"], value=["safetensors"], label="Checkpoint Format")
                    show_extra_options = gr.Checkbox(label="Show extra options", value=False)

                with gr.Row():
                    # "None" is prepended so that baking in a VAE is opt-in.
                    bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE")
                    create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "model_converter_refresh_bake_in_vae")

                with gr.Row():
                    force_position_id = gr.Checkbox(label="Force CLIP position_id to int64 before convert", value=True)
                    fix_clip = gr.Checkbox(label="Fix clip", value=False)
                    delete_known_junk_data = gr.Checkbox(label="Delete known junk data", value=False)

                # Per-part handling; hidden until "Show extra options" is ticked.
                with gr.Row(visible=False) as extra_options:
                    specific_part_conv = ["copy", "convert", "delete"]
                    unet_conv = gr.Dropdown(specific_part_conv, value="convert", label="unet")
                    text_encoder_conv = gr.Dropdown(specific_part_conv, value="convert", label="text encoder")
                    vae_conv = gr.Dropdown(specific_part_conv, value="convert", label="vae")
                    others_conv = gr.Dropdown(specific_part_conv, value="convert", label="others")

                # A gr.Button's visible text comes from `value`. Passing only
                # `label` left the button showing Gradio's default text.
                model_converter_convert = gr.Button(value="Convert",
                                                    elem_id="model_converter_convert",
                                                    variant='primary')

            with gr.Column(variant='panel'):
                submit_result = gr.Textbox(elem_id="model_converter_result", show_label=False)

            # gr_show already takes the checkbox value as its single argument,
            # so it can be used as the handler directly.
            show_extra_options.change(
                fn=gr_show,
                inputs=[show_extra_options],
                outputs=[extra_options],
            )

            # Gradio passes input values positionally, so this list must stay
            # in the same order as convert.convert_warp's parameters.
            model_converter_convert.click(
                fn=convert.convert_warp,
                inputs=[
                    model_name,
                    model_path,
                    input_directory,
                    checkpoint_formats,
                    precision, m_type, custom_name,
                    bake_in_vae,
                    unet_conv,
                    text_encoder_conv,
                    vae_conv,
                    others_conv,
                    fix_clip,
                    force_position_id,
                    delete_known_junk_data
                ],
                outputs=[submit_result]
            )

    return [(ui, "Model Converter", "model_converter")]


# Register add_tab with the web UI's tab callbacks so it is called to build the
# "Model Converter" tab. This call runs once, when this module is imported.
script_callbacks.on_ui_tabs(add_tab)