# rvc_api/modules/tabs/merge.py
import json
import os
from typing import List
import gradio as gr
import torch
from modules import models
from modules.merge import merge
from modules.tabs.inference import inference_options_ui
from modules.ui import Tab
MERGE_METHODS = {
"weight_sum": "Weight sum:A*(1-alpha)+B*alpha",
"add_diff": "Add difference:A+(B-C)*alpha",
}
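# A minimal sketch of what the two methods are expected to compute per
# parameter tensor, assuming modules.merge.merge walks the loaded state
# dicts (the real merge() may additionally honor the per-key alphas below):
#
#     weight_sum: merged[k] = a[k] * (1 - alpha) + b[k] * alpha
#     add_diff:   merged[k] = a[k] + (b[k] - c[k]) * alpha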
class Merge(Tab):
def title(self):
return "Merge"
def sort(self):
return 3
def ui(self, outlet):
def merge_ckpt(model_a, model_b, model_c, weight_text, alpha, each_key, method):
            # A Gradio dropdown can yield an empty string (or a list, before a
            # choice is made); normalize those to None.
            model_a = model_a if not isinstance(model_a, list) and model_a != "" else None
            model_b = model_b if not isinstance(model_b, list) and model_b != "" else None
            model_c = model_c if not isinstance(model_c, list) and model_c != "" else None
if each_key:
weights = json.loads(weight_text)
else:
weights = {}
            # The radio holds the human-readable description; map it back to
            # the MERGE_METHODS key expected by merge().
            method = next(k for k, v in MERGE_METHODS.items() if v == method)
return merge(
os.path.join(models.MODELS_DIR, "checkpoints", model_a),
os.path.join(models.MODELS_DIR, "checkpoints", model_b),
os.path.join(models.MODELS_DIR, "checkpoints", model_c)
if model_c
else None,
alpha,
weights,
method,
)
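        # Hypothetical usage, assuming "voice_a.pth" and "voice_b.pth" exist
        # under models/checkpoints: a plain 50/50 two-model merge.
        #
        #     merged = merge_ckpt(
        #         "voice_a.pth", "voice_b.pth", "", "{}", 0.5, False,
        #         MERGE_METHODS["weight_sum"],
        #     )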
        def merge_and_save(
            model_a, model_b, model_c, alpha, each_key, weight_text, method, out_name
        ):
            # Normalize the extension first so the existence check and
            # torch.save refer to the same path.
            if not out_name.endswith(".pth"):
                out_name += ".pth"
            out_path = os.path.join(models.MODELS_DIR, "checkpoints", out_name)
            if os.path.exists(out_path):
                return "Model name already exists."
            merged = merge_ckpt(
                model_a, model_b, model_c, weight_text, alpha, each_key, method
            )
            torch.save(merged, out_path)
            return "Success"
def merge_and_gen(
model_a,
model_b,
model_c,
alpha,
each_key,
weight_text,
method,
speaker_id,
source_audio,
embedder_name,
embedding_output_layer,
transpose,
            f0_curve_file,
pitch_extraction_algo,
auto_load_index,
faiss_index_file,
retrieval_feature_ratio,
):
merged = merge_ckpt(
model_a, model_b, model_c, weight_text, alpha, each_key, method
)
model = models.VoiceConvertModel("merge", merged)
audio = model.single(
speaker_id,
source_audio,
embedder_name,
embedding_output_layer,
transpose,
                f0_curve_file,
pitch_extraction_algo,
auto_load_index,
faiss_index_file,
retrieval_feature_ratio,
)
            tgt_sr = model.tgt_sr
            # Free the throwaway merged weights and release cached GPU memory
            # before handing the audio back to the UI.
            del merged
            del model
            torch.cuda.empty_cache()
            return "Success", (tgt_sr, audio)
def reload_model():
model_list = models.get_models()
return (
gr.Dropdown.update(choices=model_list),
gr.Dropdown.update(choices=model_list),
gr.Dropdown.update(choices=model_list),
)
        def update_speaker_ids(model):
            if model == "":
                return gr.Slider.update(
                    maximum=0,
                    visible=False,
                )
            # Load on the CPU just long enough to read the speaker count.
            model = torch.load(
                os.path.join(models.MODELS_DIR, "checkpoints", model),
                map_location="cpu",
            )
            vc_model = models.VoiceConvertModel("merge", model)
            n_spk = vc_model.n_spk
            del model
            del vc_model
            return gr.Slider.update(
                maximum=n_spk,
                visible=True,
            )
with gr.Group():
with gr.Column():
with gr.Row(equal_height=False):
model_a = gr.Dropdown(choices=models.get_models(), label="Model A")
model_b = gr.Dropdown(choices=models.get_models(), label="Model B")
model_c = gr.Dropdown(choices=models.get_models(), label="Model C")
reload_model_button = gr.Button("♻️")
reload_model_button.click(
reload_model, outputs=[model_a, model_b, model_c]
)
with gr.Row(equal_height=False):
method = gr.Radio(
label="Merge method",
choices=list(MERGE_METHODS.values()),
value="Weight sum:A*(1-alpha)+B*alpha",
)
output_name = gr.Textbox(label="Output name")
each_key = gr.Checkbox(label="Each key merge")
with gr.Row(equal_height=False):
base_alpha = gr.Slider(
label="Base alpha", minimum=0, maximum=1, value=0.5, step=0.01
)
default_weights = {}
weights = {}
                def create_weight_ui(name: str, *keys_list: List[str]):
with gr.Accordion(label=name, open=False):
with gr.Row(equal_height=False):
for keys in keys_list:
with gr.Column():
for key in keys:
default_weights[key] = 0.5
weights[key] = gr.Slider(
label=key,
minimum=0,
maximum=1,
step=0.01,
value=0.5,
)
with gr.Box(visible=False) as each_key_ui:
with gr.Column():
create_weight_ui(
"enc_p",
[
"enc_p.encoder.attn_layers.0",
"enc_p.encoder.attn_layers.1",
"enc_p.encoder.attn_layers.2",
"enc_p.encoder.attn_layers.3",
"enc_p.encoder.attn_layers.4",
"enc_p.encoder.attn_layers.5",
"enc_p.encoder.norm_layers_1.0",
"enc_p.encoder.norm_layers_1.1",
"enc_p.encoder.norm_layers_1.2",
"enc_p.encoder.norm_layers_1.3",
"enc_p.encoder.norm_layers_1.4",
"enc_p.encoder.norm_layers_1.5",
],
[
"enc_p.encoder.ffn_layers.0",
"enc_p.encoder.ffn_layers.1",
"enc_p.encoder.ffn_layers.2",
"enc_p.encoder.ffn_layers.3",
"enc_p.encoder.ffn_layers.4",
"enc_p.encoder.ffn_layers.5",
"enc_p.encoder.norm_layers_2.0",
"enc_p.encoder.norm_layers_2.1",
"enc_p.encoder.norm_layers_2.2",
"enc_p.encoder.norm_layers_2.3",
"enc_p.encoder.norm_layers_2.4",
"enc_p.encoder.norm_layers_2.5",
],
[
"enc_p.emb_phone",
"enc_p.emb_pitch",
],
)
create_weight_ui(
"dec",
[
"dec.noise_convs.0",
"dec.noise_convs.1",
"dec.noise_convs.2",
"dec.noise_convs.3",
"dec.noise_convs.4",
"dec.noise_convs.5",
"dec.ups.0",
"dec.ups.1",
"dec.ups.2",
"dec.ups.3",
],
[
"dec.resblocks.0",
"dec.resblocks.1",
"dec.resblocks.2",
"dec.resblocks.3",
"dec.resblocks.4",
"dec.resblocks.5",
"dec.resblocks.6",
"dec.resblocks.7",
"dec.resblocks.8",
"dec.resblocks.9",
"dec.resblocks.10",
"dec.resblocks.11",
],
[
"dec.m_source.l_linear",
"dec.conv_pre",
"dec.conv_post",
"dec.cond",
],
)
create_weight_ui(
"flow",
[
"flow.flows.0",
"flow.flows.1",
"flow.flows.2",
"flow.flows.3",
"flow.flows.4",
"flow.flows.5",
"flow.flows.6",
"emb_g.weight",
],
)
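                    # These prefixes appear to mirror the submodules of the RVC
                    # synthesizer: enc_p (phoneme encoder), dec (waveform
                    # decoder), flow, and the speaker embedding emb_g.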
with gr.Accordion(label="JSON", open=False):
weights_text = gr.TextArea(
value=json.dumps(default_weights),
)
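                        # When "Each key merge" is enabled, merge_ckpt parses
                        # this JSON into a dict that presumably maps each layer
                        # prefix above to its own alpha, e.g. (abridged):
                        #
                        #     {"enc_p.emb_phone": 0.3, "dec.ups.0": 0.8}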
with gr.Accordion(label="Inference options", open=False):
with gr.Row(equal_height=False):
speaker_id = gr.Slider(
minimum=0,
maximum=2333,
step=1,
label="Speaker ID",
value=0,
visible=True,
interactive=True,
)
(
source_audio,
_,
transpose,
embedder_name,
embedding_output_layer,
pitch_extraction_algo,
auto_load_index,
faiss_index_file,
retrieval_feature_ratio,
                            f0_curve_file,
) = inference_options_ui(show_out_dir=False)
with gr.Row(equal_height=False):
with gr.Column():
status = gr.Textbox(value="", label="Status")
audio_output = gr.Audio(label="Output", interactive=False)
with gr.Row(equal_height=False):
merge_and_save_button = gr.Button(
"Merge and save", variant="primary"
)
merge_and_gen_button = gr.Button("Merge and gen", variant="primary")
def each_key_on_change(each_key):
return gr.update(visible=each_key)
each_key.change(
fn=each_key_on_change,
inputs=[each_key],
outputs=[each_key_ui],
)
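            # Gradio (3.x) treats a *set* of input components specially: the
            # handler receives a single dict keyed by component, which is how
            # update_weights_text below reads every slider at once.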
def update_weights_text(data):
d = {}
for key in weights.keys():
d[key] = data[weights[key]]
return json.dumps(d)
for w in weights.values():
w.change(
fn=update_weights_text,
inputs={*weights.values()},
outputs=[weights_text],
)
merge_data = [
model_a,
model_b,
model_c,
base_alpha,
each_key,
weights_text,
method,
]
inference_opts = [
speaker_id,
source_audio,
embedder_name,
embedding_output_layer,
transpose,
            f0_curve_file,
pitch_extraction_algo,
auto_load_index,
faiss_index_file,
retrieval_feature_ratio,
]
merge_and_save_button.click(
fn=merge_and_save,
inputs=[
*merge_data,
output_name,
],
outputs=[status],
)
merge_and_gen_button.click(
fn=merge_and_gen,
inputs=[
*merge_data,
*inference_opts,
],
outputs=[status, audio_output],
)
model_a.change(
update_speaker_ids, inputs=[model_a], outputs=[speaker_id]
)