| import gradio as gr | |
| from modules import script_callbacks | |
| from modules import sd_models, sd_vae | |
| from modules.ui import create_refresh_button | |
| from scripts import convert | |
def gr_show(visible=True):
    """Return a Gradio update payload that sets only the ``visible`` property.

    Args:
        visible: Value to assign to the target component's ``visible`` flag.

    Returns:
        A dict of the form ``{"visible": visible, "__type__": "update"}``,
        the update-dict shape returned from Gradio event handlers.
    """
    payload = dict(visible=visible)
    payload["__type__"] = "update"
    return payload
def add_tab():
    """Build the "Model Converter" tab and wire up its event handlers.

    The left panel chooses the source model in one of three ways:
    - a checkpoint from the dropdown,
    - an explicit file path,
    - an input directory for batch conversion.

    It also holds precision, pruning, output-format, bake-in-VAE and
    per-component (unet / text encoder / vae / others) options. Clicking
    "Convert" passes every control's value to ``convert.convert_warp``, and
    its return value is shown in the result textbox of the right panel.

    Returns:
        A one-element list ``[(ui, "Model Converter", "model_converter")]``.
        The tuple is presumably ``(blocks, tab title, element id)`` as
        expected by ``script_callbacks.on_ui_tabs``; confirm against the
        webui extension API.
    """

    def vae_choices():
        # "None" first so the default (no VAE baked in) is always selectable;
        # re-evaluated on every refresh to pick up newly added VAE files.
        return ["None"] + list(sd_vae.vae_dict)

    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Row(equal_height=True):
            with gr.Column(variant='panel'):
                gr.HTML(value="<p>Converted checkpoints will be saved in your <b>checkpoint</b> directory.</p>")

                with gr.Tabs():
                    with gr.TabItem(label='Single process'):
                        with gr.Row():
                            model_name = gr.Dropdown(sd_models.checkpoint_tiles(),
                                                     elem_id="model_converter_model_name",
                                                     label="Model")
                            # Rescan the checkpoint folder, then refresh the dropdown choices.
                            create_refresh_button(model_name, sd_models.list_models,
                                                  lambda: {"choices": sd_models.checkpoint_tiles()},
                                                  "refresh_checkpoint_Z")
                        custom_name = gr.Textbox(label="Custom Name (Optional)")

                    with gr.TabItem(label='Input file path'):
                        with gr.Row():
                            model_path = gr.Textbox(label="model path")

                    with gr.TabItem(label='Batch from directory'):
                        with gr.Row():
                            input_directory = gr.Textbox(label="Input Directory")

                with gr.Row():
                    precision = gr.Radio(choices=["fp32", "fp16", "bf16"], value="fp16", label="Precision")
                    m_type = gr.Radio(choices=["disabled", "no-ema", "ema-only"], value="disabled", label="Pruning Methods")

                with gr.Row():
                    checkpoint_formats = gr.CheckboxGroup(choices=["ckpt", "safetensors"], value=["safetensors"], label="Checkpoint Format")
                    show_extra_options = gr.Checkbox(label="Show extra options", value=False)

                with gr.Row():
                    bake_in_vae = gr.Dropdown(choices=vae_choices(), value="None", label="Bake in VAE")
                    create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list,
                                          lambda: {"choices": vae_choices()},
                                          "model_converter_refresh_bake_in_vae")

                with gr.Row():
                    force_position_id = gr.Checkbox(label="Force CLIP position_id to int64 before convert", value=True)
                    fix_clip = gr.Checkbox(label="Fix clip", value=False)
                    delete_known_junk_data = gr.Checkbox(label="Delete known junk data", value=False)

                # Hidden until "Show extra options" is ticked (see the .change handler below).
                with gr.Row(visible=False) as extra_options:
                    specific_part_conv = ["copy", "convert", "delete"]
                    unet_conv = gr.Dropdown(specific_part_conv, value="convert", label="unet")
                    text_encoder_conv = gr.Dropdown(specific_part_conv, value="convert", label="text encoder")
                    vae_conv = gr.Dropdown(specific_part_conv, value="convert", label="vae")
                    others_conv = gr.Dropdown(specific_part_conv, value="convert", label="others")

                # A Gradio Button renders its `value` as the caption; `label` is
                # not the button text (and is rejected by newer Gradio releases).
                model_converter_convert = gr.Button(value="Convert",
                                                    elem_id="model_converter_convert",
                                                    variant='primary')

            with gr.Column(variant='panel'):
                submit_result = gr.Textbox(elem_id="model_converter_result", show_label=False)

        # The checkbox value is passed straight through as the `visible` flag.
        show_extra_options.change(
            fn=gr_show,
            inputs=[show_extra_options],
            outputs=[extra_options],
        )

        # Input order is positional and must match convert.convert_warp's signature.
        model_converter_convert.click(
            fn=convert.convert_warp,
            inputs=[
                model_name,
                model_path,
                input_directory,
                checkpoint_formats,
                precision, m_type, custom_name,
                bake_in_vae,
                unet_conv,
                text_encoder_conv,
                vae_conv,
                others_conv,
                fix_clip,
                force_position_id,
                delete_known_junk_data,
            ],
            outputs=[submit_result],
        )

    return [(ui, "Model Converter", "model_converter")]
# Register add_tab as a UI-tabs callback at import time of this script.
# Presumably the webui invokes it while building its tab list and mounts the
# returned (blocks, title, id) tuples; confirm against the extension API.
script_callbacks.on_ui_tabs(add_tab)