# rvc_api/lib/rvc/checkpoints.py
# (origin: aryo100's repository, "first commit", revision b5a064f)
import os
from collections import OrderedDict
from typing import *
import torch
def write_config(state_dict: Dict[str, Any], cfg: Dict[str, Any]):
    """Store model hyperparameters in *state_dict*, mutating it in place.

    Writes two redundant views of the same configuration:
      - ``state_dict["config"]``: the hyperparameter values as a positional
        list (insertion order of *cfg*), as expected by the RVC loader.
      - ``state_dict["params"]``: the full name -> value mapping.

    Args:
        state_dict: checkpoint dict to update.
        cfg: hyperparameter mapping; insertion order is significant.
    """
    # The original loop appended cfg's values one by one while ignoring the
    # key; list(cfg.values()) is the idiomatic equivalent (dicts preserve
    # insertion order, so the positional layout is unchanged).
    state_dict["config"] = list(cfg.values())
    state_dict["params"] = cfg
def create_trained_model(
    weights: Dict[str, Any],
    version: Literal["v1", "v2"],
    sr: str,
    f0: bool,
    emb_name: str,
    emb_ch: int,
    emb_output_layer: int,
    epoch: int,
    speaker_info: Optional[Dict[str, int]],
):
    """Build a serializable inference checkpoint from trained weights.

    Training-only tensors (posterior encoder ``enc_q``) are dropped and the
    remaining tensors are converted to half precision to shrink the file.

    Args:
        weights: raw ``state_dict`` of the trained model (name -> tensor).
        version: RVC model version tag stored in the checkpoint.
        sr: target sample rate key: ``"32k"``, ``"40k"`` or ``"48k"``.
        f0: whether the model was trained with pitch (f0) conditioning.
        emb_name: name of the feature embedder used during training.
        emb_ch: embedder channel count, stored as ``emb_channels``.
        emb_output_layer: embedder layer the features were taken from.
        epoch: training epoch count, recorded in the ``info`` field.
        speaker_info: optional speaker-name -> id mapping for multi-speaker
            models; ``None`` for single-speaker (default 109 embedding slots).

    Returns:
        OrderedDict checkpoint ready for ``torch.save``.

    Raises:
        ValueError: if *sr* is not one of the supported sample rates.
    """
    # Only these four hyperparameters differ between sample rates; the rest
    # of the generator config is shared.
    sr_specific = {
        "40k": {
            "spec_channels": 1025,
            "upsample_rates": [10, 10, 2, 2],
            "upsample_kernel_sizes": [16, 16, 4, 4],
            "sr": 40000,
        },
        "48k": {
            "spec_channels": 1025,
            "upsample_rates": [10, 6, 2, 2, 2],
            "upsample_kernel_sizes": [16, 16, 4, 4, 4],
            "sr": 48000,
        },
        "32k": {
            "spec_channels": 513,
            "upsample_rates": [10, 4, 2, 2, 2],
            "upsample_kernel_sizes": [16, 16, 4, 4, 4],
            "sr": 32000,
        },
    }
    if sr not in sr_specific:
        # The original code silently produced a checkpoint with no config
        # for an unknown sample rate; fail loudly instead.
        raise ValueError(
            f"unsupported sample rate {sr!r}; expected one of {sorted(sr_specific)}"
        )
    spec = sr_specific[sr]

    state_dict = OrderedDict()
    # Drop training-only posterior-encoder weights; halve precision for size.
    state_dict["weight"] = {
        key: tensor.half() for key, tensor in weights.items() if "enc_q" not in key
    }

    # NOTE: key order matters — "config" is stored as a positional list of
    # these values, so this must match the original layout exactly.
    cfg = {
        "spec_channels": spec["spec_channels"],
        "segment_size": 32,
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": spec["upsample_rates"],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": spec["upsample_kernel_sizes"],
        "spk_embed_dim": 109 if speaker_info is None else len(speaker_info),
        "gin_channels": 256,
        "emb_channels": emb_ch,
        "sr": spec["sr"],
    }
    # Same layout write_config produces: positional list + full mapping.
    state_dict["config"] = list(cfg.values())
    state_dict["params"] = cfg

    state_dict["version"] = version
    state_dict["info"] = f"{epoch}epoch"
    state_dict["sr"] = sr
    state_dict["f0"] = 1 if f0 else 0
    state_dict["embedder_name"] = emb_name
    state_dict["embedder_output_layer"] = emb_output_layer
    if speaker_info is not None:
        # Invert to id -> name (both stringified) for the inference side.
        state_dict["speaker_info"] = {str(v): str(k) for k, v in speaker_info.items()}
    return state_dict
def save(
    model,
    version: Literal["v1", "v2"],
    sr: str,
    f0: bool,
    emb_name: str,
    emb_ch: int,
    emb_output_layer: int,
    filepath: str,
    epoch: int,
    speaker_info: Optional[Dict[str, int]],
):
    """Extract the model's weights, build a checkpoint and write it to disk.

    Args:
        model: trained torch module (possibly wrapped in DataParallel /
            DistributedDataParallel, which expose the real module as
            ``model.module``).
        version, sr, f0, emb_name, emb_ch, emb_output_layer, epoch,
        speaker_info: forwarded to ``create_trained_model``.
        filepath: destination path; missing parent directories are created.
    """
    # Unwrap (Distributed)DataParallel to get the underlying module's weights.
    if hasattr(model, "module"):
        raw_weights = model.module.state_dict()
    else:
        raw_weights = model.state_dict()
    print(f"save: emb_name: {emb_name} {emb_ch}")
    checkpoint = create_trained_model(
        raw_weights,
        version,
        sr,
        f0,
        emb_name,
        emb_ch,
        emb_output_layer,
        epoch,
        speaker_info,
    )
    # BUG FIX: os.path.dirname() returns "" for a bare filename, and
    # os.makedirs("") raises FileNotFoundError even with exist_ok=True.
    # Only create the parent directory when there is one.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(checkpoint, filepath)