Spaces:
Sleeping
Sleeping
| import os | |
| from collections import OrderedDict | |
| from typing import * | |
| import torch | |
| def write_config(state_dict: Dict[str, Any], cfg: Dict[str, Any]): | |
| state_dict["config"] = [] | |
| for key, x in cfg.items(): | |
| state_dict["config"].append(x) | |
| state_dict["params"] = cfg | |
| def create_trained_model( | |
| weights: Dict[str, Any], | |
| version: Literal["v1", "v2"], | |
| sr: str, | |
| f0: bool, | |
| emb_name: str, | |
| emb_ch: int, | |
| emb_output_layer: int, | |
| epoch: int, | |
| speaker_info: Optional[dict[str, int]] | |
| ): | |
| state_dict = OrderedDict() | |
| state_dict["weight"] = {} | |
| for key in weights.keys(): | |
| if "enc_q" in key: | |
| continue | |
| state_dict["weight"][key] = weights[key].half() | |
| if sr == "40k": | |
| write_config( | |
| state_dict, | |
| { | |
| "spec_channels": 1025, | |
| "segment_size": 32, | |
| "inter_channels": 192, | |
| "hidden_channels": 192, | |
| "filter_channels": 768, | |
| "n_heads": 2, | |
| "n_layers": 6, | |
| "kernel_size": 3, | |
| "p_dropout": 0, | |
| "resblock": "1", | |
| "resblock_kernel_sizes": [3, 7, 11], | |
| "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], | |
| "upsample_rates": [10, 10, 2, 2], | |
| "upsample_initial_channel": 512, | |
| "upsample_kernel_sizes": [16, 16, 4, 4], | |
| "spk_embed_dim": 109 if speaker_info is None else len(speaker_info), | |
| "gin_channels": 256, | |
| "emb_channels": emb_ch, | |
| "sr": 40000, | |
| }, | |
| ) | |
| elif sr == "48k": | |
| write_config( | |
| state_dict, | |
| { | |
| "spec_channels": 1025, | |
| "segment_size": 32, | |
| "inter_channels": 192, | |
| "hidden_channels": 192, | |
| "filter_channels": 768, | |
| "n_heads": 2, | |
| "n_layers": 6, | |
| "kernel_size": 3, | |
| "p_dropout": 0, | |
| "resblock": "1", | |
| "resblock_kernel_sizes": [3, 7, 11], | |
| "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], | |
| "upsample_rates": [10, 6, 2, 2, 2], | |
| "upsample_initial_channel": 512, | |
| "upsample_kernel_sizes": [16, 16, 4, 4, 4], | |
| "spk_embed_dim": 109 if speaker_info is None else len(speaker_info), | |
| "gin_channels": 256, | |
| "emb_channels": emb_ch, | |
| "sr": 48000, | |
| }, | |
| ) | |
| elif sr == "32k": | |
| write_config( | |
| state_dict, | |
| { | |
| "spec_channels": 513, | |
| "segment_size": 32, | |
| "inter_channels": 192, | |
| "hidden_channels": 192, | |
| "filter_channels": 768, | |
| "n_heads": 2, | |
| "n_layers": 6, | |
| "kernel_size": 3, | |
| "p_dropout": 0, | |
| "resblock": "1", | |
| "resblock_kernel_sizes": [3, 7, 11], | |
| "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], | |
| "upsample_rates": [10, 4, 2, 2, 2], | |
| "upsample_initial_channel": 512, | |
| "upsample_kernel_sizes": [16, 16, 4, 4, 4], | |
| "spk_embed_dim": 109 if speaker_info is None else len(speaker_info), | |
| "gin_channels": 256, | |
| "emb_channels": emb_ch, | |
| "sr": 32000, | |
| }, | |
| ) | |
| state_dict["version"] = version | |
| state_dict["info"] = f"{epoch}epoch" | |
| state_dict["sr"] = sr | |
| state_dict["f0"] = 1 if f0 else 0 | |
| state_dict["embedder_name"] = emb_name | |
| state_dict["embedder_output_layer"] = emb_output_layer | |
| if not speaker_info is None: | |
| state_dict["speaker_info"] = {str(v): str(k) for k, v in speaker_info.items()} | |
| return state_dict | |
| def save( | |
| model, | |
| version: Literal["v1", "v2"], | |
| sr: str, | |
| f0: bool, | |
| emb_name: str, | |
| emb_ch: int, | |
| emb_output_layer: int, | |
| filepath: str, | |
| epoch: int, | |
| speaker_info: Optional[dict[str, int]] | |
| ): | |
| if hasattr(model, "module"): | |
| state_dict = model.module.state_dict() | |
| else: | |
| state_dict = model.state_dict() | |
| print(f"save: emb_name: {emb_name} {emb_ch}") | |
| state_dict = create_trained_model( | |
| state_dict, | |
| version, | |
| sr, | |
| f0, | |
| emb_name, | |
| emb_ch, | |
| emb_output_layer, | |
| epoch, | |
| speaker_info | |
| ) | |
| os.makedirs(os.path.dirname(filepath), exist_ok=True) | |
| torch.save(state_dict, filepath) | |