Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| from typing import * | |
| import gradio as gr | |
| import torch | |
| from modules import models | |
| from modules.merge import merge | |
| from modules.tabs.inference import inference_options_ui | |
| from modules.ui import Tab | |
# Human-readable radio label for each supported merge algorithm, keyed by
# the internal method name that modules.merge.merge() expects.
MERGE_METHODS = dict(
    weight_sum="Weight sum:A*(1-alpha)+B*alpha",
    add_diff="Add difference:A+(B-C)*alpha",
)
class Merge(Tab):
    """Tab for merging two or three voice-conversion checkpoints.

    The merged model can either be saved back into the checkpoints
    directory or used immediately for a single inference run.
    """

    def title(self):
        return "Merge"

    def sort(self):
        # Position of this tab among the other tabs.
        return 3

    def ui(self, outlet):
        def merge_ckpt(model_a, model_b, model_c, weight_text, alpha, each_key, method):
            """Merge the selected checkpoints and return the merged model object."""
            # Dropdowns may yield [] or "" when nothing is selected; normalize to None.
            model_a = model_a if type(model_a) != list and model_a != "" else None
            model_b = model_b if type(model_b) != list and model_b != "" else None
            model_c = model_c if type(model_c) != list and model_c != "" else None

            # Per-key weights only apply when "Each key merge" is checked;
            # otherwise the base alpha is used uniformly.
            if each_key:
                weights = json.loads(weight_text)
            else:
                weights = {}

            # Map the human-readable radio label back to its internal key.
            method = [k for k, v in MERGE_METHODS.items() if v == method][0]

            return merge(
                os.path.join(models.MODELS_DIR, "checkpoints", model_a),
                os.path.join(models.MODELS_DIR, "checkpoints", model_b),
                os.path.join(models.MODELS_DIR, "checkpoints", model_c)
                if model_c
                else None,
                alpha,
                weights,
                method,
            )

        def merge_and_save(
            model_a, model_b, model_c, alpha, each_key, weight_text, method, out_name
        ):
            """Merge and write the result into the checkpoints directory.

            Returns a status string for the UI.
            """
            # FIX: normalize the extension *before* the existence check so that
            # "foo" and "foo.pth" refer to the same target. The original code
            # checked the un-suffixed name, then appended ".pth" and saved --
            # silently overwriting an existing "foo.pth". (A leftover debug
            # print(each_key) was also removed.)
            if not out_name.endswith(".pth"):
                out_name += ".pth"
            out_path = os.path.join(models.MODELS_DIR, "checkpoints", out_name)
            if os.path.exists(out_path):
                return "Model name already exists."
            merged = merge_ckpt(
                model_a, model_b, model_c, weight_text, alpha, each_key, method
            )
            torch.save(merged, out_path)
            return "Success"

        def merge_and_gen(
            model_a,
            model_b,
            model_c,
            alpha,
            each_key,
            weight_text,
            method,
            speaker_id,
            source_audio,
            embedder_name,
            embedding_output_layer,
            transpose,
            fo_curve_file,
            pitch_extraction_algo,
            auto_load_index,
            faiss_index_file,
            retrieval_feature_ratio,
        ):
            """Merge in memory and immediately run single-file inference.

            Returns (status string, (sample rate, audio)) for the UI.
            """
            merged = merge_ckpt(
                model_a, model_b, model_c, weight_text, alpha, each_key, method
            )
            model = models.VoiceConvertModel("merge", merged)
            audio = model.single(
                speaker_id,
                source_audio,
                embedder_name,
                embedding_output_layer,
                transpose,
                fo_curve_file,
                pitch_extraction_algo,
                auto_load_index,
                faiss_index_file,
                retrieval_feature_ratio,
            )
            tgt_sr = model.tgt_sr
            # Drop the merged weights and model before returning so the GPU
            # memory is released between runs.
            del merged
            del model
            torch.cuda.empty_cache()
            return "Success", (tgt_sr, audio)

        def reload_model():
            """Refresh the three model dropdowns from the models directory."""
            model_list = models.get_models()
            return (
                gr.Dropdown.update(choices=model_list),
                gr.Dropdown.update(choices=model_list),
                gr.Dropdown.update(choices=model_list),
            )

        def update_speaker_ids(model):
            """Resize the speaker-id slider to the selected model's speaker count."""
            if model == "":
                return gr.Slider.update(
                    maximum=0,
                    visible=False,
                )
            state_dict = torch.load(
                os.path.join(models.MODELS_DIR, "checkpoints", model),
                map_location="cpu",
            )
            vc_model = models.VoiceConvertModel("merge", state_dict)
            max_speaker_id = vc_model.n_spk  # renamed from `max`: don't shadow the builtin
            del state_dict
            del vc_model
            return gr.Slider.update(
                maximum=max_speaker_id,
                visible=True,
            )

        with gr.Group():
            with gr.Column():
                with gr.Row(equal_height=False):
                    model_a = gr.Dropdown(choices=models.get_models(), label="Model A")
                    model_b = gr.Dropdown(choices=models.get_models(), label="Model B")
                    model_c = gr.Dropdown(choices=models.get_models(), label="Model C")
                    reload_model_button = gr.Button("♻️")
                    reload_model_button.click(
                        reload_model, outputs=[model_a, model_b, model_c]
                    )
                with gr.Row(equal_height=False):
                    method = gr.Radio(
                        label="Merge method",
                        choices=list(MERGE_METHODS.values()),
                        # Reference the dict instead of duplicating the label string.
                        value=MERGE_METHODS["weight_sum"],
                    )
                    output_name = gr.Textbox(label="Output name")
                    each_key = gr.Checkbox(label="Each key merge")
                with gr.Row(equal_height=False):
                    base_alpha = gr.Slider(
                        label="Base alpha", minimum=0, maximum=1, value=0.5, step=0.01
                    )

                # key -> default weight, and key -> slider component; both are
                # populated as create_weight_ui() builds the accordions below.
                default_weights = {}
                weights = {}

                def create_weight_ui(name: str, *keys_list: List[str]):
                    """One accordion with a column of per-key sliders per key list."""
                    with gr.Accordion(label=name, open=False):
                        with gr.Row(equal_height=False):
                            for keys in keys_list:
                                with gr.Column():
                                    for key in keys:
                                        default_weights[key] = 0.5
                                        weights[key] = gr.Slider(
                                            label=key,
                                            minimum=0,
                                            maximum=1,
                                            step=0.01,
                                            value=0.5,
                                        )

                with gr.Box(visible=False) as each_key_ui:
                    with gr.Column():
                        # Key lists are generated; order and values match the
                        # original hand-written literals exactly.
                        create_weight_ui(
                            "enc_p",
                            [f"enc_p.encoder.attn_layers.{i}" for i in range(6)]
                            + [f"enc_p.encoder.norm_layers_1.{i}" for i in range(6)],
                            [f"enc_p.encoder.ffn_layers.{i}" for i in range(6)]
                            + [f"enc_p.encoder.norm_layers_2.{i}" for i in range(6)],
                            [
                                "enc_p.emb_phone",
                                "enc_p.emb_pitch",
                            ],
                        )
                        create_weight_ui(
                            "dec",
                            [f"dec.noise_convs.{i}" for i in range(6)]
                            + [f"dec.ups.{i}" for i in range(4)],
                            [f"dec.resblocks.{i}" for i in range(12)],
                            [
                                "dec.m_source.l_linear",
                                "dec.conv_pre",
                                "dec.conv_post",
                                "dec.cond",
                            ],
                        )
                        create_weight_ui(
                            "flow",
                            [f"flow.flows.{i}" for i in range(7)] + ["emb_g.weight"],
                        )
                        with gr.Accordion(label="JSON", open=False):
                            weights_text = gr.TextArea(
                                value=json.dumps(default_weights),
                            )

                with gr.Accordion(label="Inference options", open=False):
                    with gr.Row(equal_height=False):
                        speaker_id = gr.Slider(
                            minimum=0,
                            maximum=2333,
                            step=1,
                            label="Speaker ID",
                            value=0,
                            visible=True,
                            interactive=True,
                        )
                        (
                            source_audio,
                            _,
                            transpose,
                            embedder_name,
                            embedding_output_layer,
                            pitch_extraction_algo,
                            auto_load_index,
                            faiss_index_file,
                            retrieval_feature_ratio,
                            fo_curve_file,
                        ) = inference_options_ui(show_out_dir=False)

                with gr.Row(equal_height=False):
                    with gr.Column():
                        status = gr.Textbox(value="", label="Status")
                        audio_output = gr.Audio(label="Output", interactive=False)

                with gr.Row(equal_height=False):
                    merge_and_save_button = gr.Button(
                        "Merge and save", variant="primary"
                    )
                    merge_and_gen_button = gr.Button("Merge and gen", variant="primary")

        def each_key_on_change(each_key):
            # Show the per-key sliders only while "Each key merge" is checked.
            return gr.update(visible=each_key)

        each_key.change(
            fn=each_key_on_change,
            inputs=[each_key],
            outputs=[each_key_ui],
        )

        def update_weights_text(data):
            # `inputs` is passed as a set, so `data` maps component -> value;
            # serialize the current slider values back into the JSON textarea.
            d = {}
            for key in weights.keys():
                d[key] = data[weights[key]]
            return json.dumps(d)

        for w in weights.values():
            w.change(
                fn=update_weights_text,
                inputs={*weights.values()},
                outputs=[weights_text],
            )

        merge_data = [
            model_a,
            model_b,
            model_c,
            base_alpha,
            each_key,
            weights_text,
            method,
        ]
        inference_opts = [
            speaker_id,
            source_audio,
            embedder_name,
            embedding_output_layer,
            transpose,
            fo_curve_file,
            pitch_extraction_algo,
            auto_load_index,
            faiss_index_file,
            retrieval_feature_ratio,
        ]
        merge_and_save_button.click(
            fn=merge_and_save,
            inputs=[
                *merge_data,
                output_name,
            ],
            outputs=[status],
        )
        merge_and_gen_button.click(
            fn=merge_and_gen,
            inputs=[
                *merge_data,
                *inference_opts,
            ],
            outputs=[status, audio_output],
        )
        model_a.change(
            update_speaker_ids, inputs=[model_a], outputs=[speaker_id]
        )