9voltfan2009 commited on
Commit
4de7fc4
·
verified ·
1 Parent(s): 5ea0d8d

Delete main

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. main/app/app.py +0 -218
  2. main/app/based/utils.py +0 -1534
  3. main/app/parser.py +0 -339
  4. main/app/tabs/inference/inference.py +0 -596
  5. main/app/tabs/models/model.py +0 -465
  6. main/app/tabs/utils/utils.py +0 -305
  7. main/app/tensorboard.py +0 -30
  8. main/configs/config.json +0 -549
  9. main/configs/config.py +0 -90
  10. main/configs/decrypt.bin +0 -3
  11. main/configs/v1/32000.json +0 -46
  12. main/configs/v1/40000.json +0 -46
  13. main/configs/v1/48000.json +0 -46
  14. main/configs/v2/32000.json +0 -42
  15. main/configs/v2/40000.json +0 -42
  16. main/configs/v2/48000.json +0 -42
  17. main/inference/audio_effects.py +0 -180
  18. main/inference/audioldm2.py +0 -210
  19. main/inference/convert.py +0 -590
  20. main/inference/create_dataset.py +0 -230
  21. main/inference/create_index.py +0 -90
  22. main/inference/extract.py +0 -360
  23. main/inference/preprocess.py +0 -270
  24. main/inference/separator_music.py +0 -310
  25. main/inference/train.py +0 -990
  26. main/library/algorithm/commons.py +0 -60
  27. main/library/algorithm/modules.py +0 -60
  28. main/library/algorithm/mrf_hifigan.py +0 -150
  29. main/library/algorithm/onnx_export.py +0 -50
  30. main/library/algorithm/refinegan.py +0 -170
  31. main/library/algorithm/residuals.py +0 -140
  32. main/library/algorithm/separator.py +0 -320
  33. main/library/algorithm/stftpitchshift.py +0 -250
  34. main/library/algorithm/synthesizers.py +0 -490
  35. main/library/architectures/demucs_separator.py +0 -180
  36. main/library/architectures/fairseq.py +0 -1480
  37. main/library/architectures/mdx_separator.py +0 -320
  38. main/library/audioldm2/models.py +0 -330
  39. main/library/audioldm2/utils.py +0 -40
  40. main/library/predictors/CREPE.py +0 -210
  41. main/library/predictors/FCPE.py +0 -1000
  42. main/library/predictors/RMVPE.py +0 -260
  43. main/library/predictors/SWIPE.py +0 -140
  44. main/library/predictors/WORLD_WRAPPER.py +0 -90
  45. main/library/speaker_diarization/ECAPA_TDNN.py +0 -280
  46. main/library/speaker_diarization/audio.py +0 -170
  47. main/library/speaker_diarization/embedding.py +0 -90
  48. main/library/speaker_diarization/encoder.py +0 -250
  49. main/library/speaker_diarization/features.py +0 -520
  50. main/library/speaker_diarization/parameter_transfer.py +0 -120
main/app/app.py DELETED
@@ -1,218 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
- import shutil
5
- import librosa
6
- import logging
7
- import requests
8
- import subprocess
9
- import numpy as np
10
- import gradio as gr
11
- import soundfile as sf
12
- from time import sleep
13
- from multiprocessing import cpu_count
14
-
15
- sys.path.append(os.getcwd())
16
- from main.app.tabs.inference.inference import inference_tabs
17
- from main.app.tabs.models.model import model_tabs
18
- from main.app.tabs.utils.utils import utils_tabs
19
-
20
- from main.tools import huggingface
21
- from main.configs.config import Config
22
- from main.app.based.utils import *
23
-
24
- with gr.Blocks(title="Ultimate RVC Maker ⚡", theme=theme) as app:
25
- gr.HTML("<h1 style='text-align: center;'>Ultimate RVC Maker ⚡</h1>")
26
- gr.Markdown(
27
- f"""
28
- If you liked this HF Space you can give me a ❤️
29
-
30
- Try Ultimate RVC Maker WebUI using Colab [here](https://colab.research.google.com/github/TheNeodev/Notebook/blob/main/RVC-MAKER.ipynb)
31
- """
32
- )
33
- with gr.Tabs():
34
-
35
-
36
- with gr.TabItem("Inference"):
37
- inference_tabs()
38
- with gr.TabItem("Model Options"):
39
- model_tabs()
40
- with gr.TabItem(translations["separator_tab"], visible=configs.get("separator_tab", True)):
41
- gr.Markdown(f"## {translations['separator_tab']}")
42
- with gr.Row():
43
- gr.Markdown(translations["4_part"])
44
- with gr.Row():
45
- with gr.Column():
46
- with gr.Group():
47
- with gr.Row(equal_height=True):
48
- cleaner = gr.Checkbox(label=translations["clear_audio"], value=False, interactive=True, min_width=140)
49
- backing = gr.Checkbox(label=translations["separator_backing"], value=False, interactive=True, min_width=140)
50
- reverb = gr.Checkbox(label=translations["dereveb_audio"], value=False, interactive=True, min_width=140)
51
- backing_reverb = gr.Checkbox(label=translations["dereveb_backing"], value=False, interactive=False, min_width=140)
52
- denoise = gr.Checkbox(label=translations["denoise_mdx"], value=False, interactive=False, min_width=140)
53
- with gr.Row(equal_height=True):
54
- separator_model = gr.Dropdown(label=translations["separator_model"], value=uvr_model[0], choices=uvr_model, interactive=True)
55
- separator_backing_model = gr.Dropdown(label=translations["separator_backing_model"], value="Version-1", choices=["Version-1", "Version-2"], interactive=True, visible=backing.value)
56
-
57
- with gr.Row():
58
- with gr.Column():
59
- with gr.Group():
60
- with gr.Row(equal_height=True):
61
- shifts = gr.Slider(label=translations["shift"], info=translations["shift_info"], minimum=1, maximum=20, value=2, step=1, interactive=True)
62
- segment_size = gr.Slider(label=translations["segments_size"], info=translations["segments_size_info"], minimum=32, maximum=3072, value=256, step=32, interactive=True)
63
- with gr.Row():
64
- mdx_batch_size = gr.Slider(label=translations["batch_size"], info=translations["mdx_batch_size_info"], minimum=1, maximum=64, value=1, step=1, interactive=True, visible=backing.value or reverb.value or separator_model.value in mdx_model)
65
- with gr.Column():
66
- with gr.Group():
67
- with gr.Row():
68
- overlap = gr.Radio(label=translations["overlap"], info=translations["overlap_info"], choices=["0.25", "0.5", "0.75", "0.99"], value="0.25", interactive=True)
69
- with gr.Row():
70
- mdx_hop_length = gr.Slider(label="Hop length", info=translations["hop_length_info"], minimum=1, maximum=8192, value=1024, step=1, interactive=True, visible=backing.value or reverb.value or separator_model.value in mdx_model)
71
- with gr.Column():
72
- with gr.Row():
73
- clean_strength = gr.Slider(label=translations["clean_strength"], info=translations["clean_strength_info"], minimum=0, maximum=1, value=0.5, step=0.1, interactive=True, visible=cleaner.value)
74
- sample_rate1 = gr.Slider(minimum=8000, maximum=96000, step=1, value=44100, label=translations["sr"], info=translations["sr_info"], interactive=True)
75
- input = gr.File(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"])
76
- audio_input = gr.Audio(show_download_button=True, interactive=False, label=translations["input_audio"])
77
- with gr.Column():
78
- with gr.Accordion(translations["use_url"], open=False):
79
- url = gr.Textbox(label=translations["url_audio"], value="", placeholder="https://www.youtube.com/...", scale=6)
80
- download_button = gr.Button(translations["downloads"])
81
- with gr.Column():
82
- with gr.Accordion(translations["input_output"], open=False):
83
- format = gr.Radio(label=translations["export_format"], info=translations["export_info"], choices=["wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"], value="wav", interactive=True)
84
- input_audio = gr.Dropdown(label=translations["audio_path"], value="", choices=paths_for_files, allow_custom_value=True, interactive=True)
85
- refesh_separator = gr.Button(translations["refesh"])
86
- output_separator = gr.Textbox(label=translations["output_folder"], value="audios", placeholder="audios", info=translations["output_folder_info"], interactive=True)
87
- separator_button = gr.Button(translations["separator_tab"], variant="primary")
88
- with gr.Row():
89
- gr.Markdown(translations["output_separator"])
90
- with gr.Row():
91
- instruments_audio = gr.Audio(show_download_button=True, interactive=False, label=translations["instruments"])
92
- original_vocals = gr.Audio(show_download_button=True, interactive=False, label=translations["original_vocal"])
93
- main_vocals = gr.Audio(show_download_button=True, interactive=False, label=translations["main_vocal"], visible=backing.value)
94
- backing_vocals = gr.Audio(show_download_button=True, interactive=False, label=translations["backing_vocal"], visible=backing.value)
95
- with gr.Row():
96
- separator_model.change(fn=lambda a, b, c: [visible(a or b or c in mdx_model), visible(a or b or c in mdx_model), valueFalse_interactive(a or b or c in mdx_model), visible(c not in mdx_model)], inputs=[backing, reverb, separator_model], outputs=[mdx_batch_size, mdx_hop_length, denoise, shifts])
97
- backing.change(fn=lambda a, b, c: [visible(a or b or c in mdx_model), visible(a or b or c in mdx_model), valueFalse_interactive(a or b or c in mdx_model), visible(a), visible(a), visible(a), valueFalse_interactive(a and b)], inputs=[backing, reverb, separator_model], outputs=[mdx_batch_size, mdx_hop_length, denoise, separator_backing_model, main_vocals, backing_vocals, backing_reverb])
98
- reverb.change(fn=lambda a, b, c: [visible(a or b or c in mdx_model), visible(a or b or c in mdx_model), valueFalse_interactive(a or b or c in mdx_model), valueFalse_interactive(a and b)], inputs=[backing, reverb, separator_model], outputs=[mdx_batch_size, mdx_hop_length, denoise, backing_reverb])
99
- with gr.Row():
100
- input_audio.change(fn=lambda audio: audio if os.path.isfile(audio) else None, inputs=[input_audio], outputs=[audio_input])
101
- cleaner.change(fn=visible, inputs=[cleaner], outputs=[clean_strength])
102
- with gr.Row():
103
- input.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("audios")), inputs=[input], outputs=[input_audio])
104
- refesh_separator.click(fn=change_audios_choices, inputs=[input_audio], outputs=[input_audio])
105
- with gr.Row():
106
- download_button.click(
107
- fn=download_url,
108
- inputs=[url],
109
- outputs=[input_audio, audio_input, url],
110
- api_name='download_url'
111
- )
112
- separator_button.click(
113
- fn=separator_music,
114
- inputs=[
115
- input_audio,
116
- output_separator,
117
- format,
118
- shifts,
119
- segment_size,
120
- overlap,
121
- cleaner,
122
- clean_strength,
123
- denoise,
124
- separator_model,
125
- separator_backing_model,
126
- backing,
127
- reverb,
128
- backing_reverb,
129
- mdx_hop_length,
130
- mdx_batch_size,
131
- sample_rate1
132
- ],
133
- outputs=[original_vocals, instruments_audio, main_vocals, backing_vocals],
134
- api_name='separator_music'
135
- )
136
- utils_tabs()
137
- with gr.TabItem(translations["settings"], visible=configs.get("settings_tab", True)):
138
- gr.Markdown(translations["settings_markdown"])
139
- with gr.Row():
140
- gr.Markdown(translations["settings_markdown_2"])
141
- with gr.Row():
142
- toggle_button = gr.Button(translations["change_light_dark"], variant="secondary", scale=2)
143
- with gr.Row():
144
- with gr.Column():
145
- language_dropdown = gr.Dropdown(label=translations["lang"], interactive=True, info=translations["lang_restart"], choices=configs.get("support_language", "vi-VN"), value=language)
146
- change_lang = gr.Button(translations["change_lang"], variant="primary", scale=2)
147
- with gr.Column():
148
- theme_dropdown = gr.Dropdown(label=translations["theme"], interactive=True, info=translations["theme_restart"], choices=configs.get("themes", theme), value=theme, allow_custom_value=True)
149
- changetheme = gr.Button(translations["theme_button"], variant="primary", scale=2)
150
- with gr.Row():
151
- with gr.Column():
152
- fp_choice = gr.Radio(choices=["fp16","fp32"], value="fp16" if configs.get("fp16", False) else "fp32", label=translations["precision"], info=translations["precision_info"], interactive=True)
153
- fp_button = gr.Button(translations["update_precision"], variant="secondary", scale=2)
154
- with gr.Column():
155
- font_choice = gr.Textbox(label=translations["font"], info=translations["font_info"], value=font, interactive=True)
156
- font_button = gr.Button(translations["change_font"])
157
- with gr.Row():
158
- with gr.Column():
159
- with gr.Accordion(translations["stop"], open=False):
160
- separate_stop = gr.Button(translations["stop_separate"])
161
- convert_stop = gr.Button(translations["stop_convert"])
162
- create_dataset_stop = gr.Button(translations["stop_create_dataset"])
163
- audioldm2_stop = gr.Button(translations["stop_audioldm2"])
164
- with gr.Accordion(translations["stop_training"], open=False):
165
- model_name_stop = gr.Textbox(label=translations["modelname"], info=translations["training_model_name"], value="", placeholder=translations["modelname"], interactive=True)
166
- preprocess_stop = gr.Button(translations["stop_preprocess"])
167
- extract_stop = gr.Button(translations["stop_extract"])
168
- train_stop = gr.Button(translations["stop_training"])
169
- with gr.Row():
170
- toggle_button.click(fn=None, js="() => {document.body.classList.toggle('dark')}")
171
- fp_button.click(fn=change_fp, inputs=[fp_choice], outputs=[fp_choice])
172
- with gr.Row():
173
- change_lang.click(fn=change_language, inputs=[language_dropdown], outputs=[])
174
- changetheme.click(fn=change_theme, inputs=[theme_dropdown], outputs=[])
175
- font_button.click(fn=change_font, inputs=[font_choice], outputs=[])
176
- with gr.Row():
177
- change_lang.click(fn=None, js="setTimeout(function() {location.reload()}, 15000)", inputs=[], outputs=[])
178
- changetheme.click(fn=None, js="setTimeout(function() {location.reload()}, 15000)", inputs=[], outputs=[])
179
- font_button.click(fn=None, js="setTimeout(function() {location.reload()}, 15000)", inputs=[], outputs=[])
180
- with gr.Row():
181
- separate_stop.click(fn=lambda: stop_pid("separate_pid", None, False), inputs=[], outputs=[])
182
- convert_stop.click(fn=lambda: stop_pid("convert_pid", None, False), inputs=[], outputs=[])
183
- create_dataset_stop.click(fn=lambda: stop_pid("create_dataset_pid", None, False), inputs=[], outputs=[])
184
- with gr.Row():
185
- preprocess_stop.click(fn=lambda model_name_stop: stop_pid("preprocess_pid", model_name_stop, False), inputs=[model_name_stop], outputs=[])
186
- extract_stop.click(fn=lambda model_name_stop: stop_pid("extract_pid", model_name_stop, False), inputs=[model_name_stop], outputs=[])
187
- train_stop.click(fn=lambda model_name_stop: stop_pid("train_pid", model_name_stop, True), inputs=[model_name_stop], outputs=[])
188
- with gr.Row():
189
- audioldm2_stop.click(fn=lambda: stop_pid("audioldm2_pid", None, False), inputs=[], outputs=[])
190
-
191
-
192
- with gr.Row():
193
- gr.Markdown(translations["terms_of_use"])
194
- gr.Markdown(translations["exemption"])
195
-
196
- logger.info(translations["start_app"])
197
- logger.info(translations["set_lang"].format(lang=language))
198
-
199
- port = configs.get("app_port", 7860)
200
-
201
- for i in range(configs.get("num_of_restart", 5)):
202
- try:
203
- app.queue().launch(
204
- favicon_path=os.path.join("assets", "ico.png"),
205
- server_name=configs.get("server_name", "0.0.0.0"),
206
- server_port=port,
207
- show_error=configs.get("app_show_error", False),
208
- inbrowser="--open" in sys.argv,
209
- share="--share" in sys.argv,
210
- allowed_paths=allow_disk
211
- )
212
- break
213
- except OSError:
214
- logger.debug(translations["port"].format(port=port))
215
- port -= 1
216
- except Exception as e:
217
- logger.error(translations["error_occurred"].format(e=e))
218
- sys.exit(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/app/based/utils.py DELETED
@@ -1,1534 +0,0 @@
1
- import os
2
- import re
3
- import ssl
4
- import sys
5
- import json
6
- import torch
7
- import codecs
8
- import shutil
9
- import asyncio
10
- import librosa
11
- import logging
12
- import datetime
13
- import platform
14
- import requests
15
- import warnings
16
- import threading
17
- import subprocess
18
- import logging.handlers
19
-
20
- import numpy as np
21
- import gradio as gr
22
- import pandas as pd
23
- import soundfile as sf
24
-
25
- from time import sleep
26
- from multiprocessing import cpu_count
27
-
28
- sys.path.append(os.getcwd())
29
-
30
- from main.tools import huggingface
31
- from main.configs.config import Config
32
-
33
- ssl._create_default_https_context = ssl._create_unverified_context
34
- logger = logging.getLogger(__name__)
35
- logger.propagate = False
36
-
37
- if logger.hasHandlers(): logger.handlers.clear()
38
- else:
39
- console_handler = logging.StreamHandler()
40
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
41
- console_handler.setFormatter(console_formatter)
42
- console_handler.setLevel(logging.INFO)
43
- file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "app.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
44
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
45
- file_handler.setFormatter(file_formatter)
46
- file_handler.setLevel(logging.DEBUG)
47
- logger.addHandler(console_handler)
48
- logger.addHandler(file_handler)
49
- logger.setLevel(logging.DEBUG)
50
-
51
- warnings.filterwarnings("ignore")
52
- for l in ["httpx", "gradio", "uvicorn", "httpcore", "urllib3"]:
53
- logging.getLogger(l).setLevel(logging.ERROR)
54
-
55
- config = Config()
56
- python = sys.executable
57
- translations = config.translations
58
- configs_json = os.path.join("main", "configs", "config.json")
59
- configs = json.load(open(configs_json, "r"))
60
-
61
- os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
62
- os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
63
-
64
- if config.device in ["cpu", "mps"] and configs.get("fp16", False):
65
- logger.warning(translations["fp16_not_support"])
66
- configs["fp16"] = config.is_half = False
67
- with open(configs_json, "w") as f:
68
- json.dump(configs, f, indent=4)
69
-
70
- models, model_options = {}, {}
71
-
72
- method_f0 = ["mangio-crepe-full", "crepe-full", "fcpe", "rmvpe", "harvest", "pyin"]
73
- method_f0_full = ["pm", "dio", "mangio-crepe-tiny", "mangio-crepe-small", "mangio-crepe-medium", "mangio-crepe-large", "mangio-crepe-full", "crepe-tiny", "crepe-small", "crepe-medium", "crepe-large", "crepe-full", "fcpe", "fcpe-legacy", "rmvpe", "rmvpe-legacy", "harvest", "yin", "pyin", "swipe"]
74
-
75
- embedders_mode = ["fairseq", "onnx", "transformers", "spin"]
76
- embedders_model = ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "portuguese_hubert_base", "custom"]
77
-
78
- paths_for_files = sorted([os.path.abspath(os.path.join(root, f)) for root, _, files in os.walk("audios") for f in files if os.path.splitext(f)[1].lower() in (".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3")])
79
- model_name, index_path, delete_index = sorted(list(model for model in os.listdir(os.path.join("assets", "weights")) if model.endswith((".pth", ".onnx")) and not model.startswith("G_") and not model.startswith("D_"))), sorted([os.path.join(root, name) for root, _, files in os.walk(os.path.join("assets", "logs"), topdown=False) for name in files if name.endswith(".index") and "trained" not in name]), sorted([os.path.join("assets", "logs", f) for f in os.listdir(os.path.join("assets", "logs")) if "mute" not in f and os.path.isdir(os.path.join("assets", "logs", f))])
80
- pretrainedD, pretrainedG, Allpretrained = ([model for model in os.listdir(os.path.join("assets", "models", "pretrained_custom")) if model.endswith(".pth") and "D" in model], [model for model in os.listdir(os.path.join("assets", "models", "pretrained_custom")) if model.endswith(".pth") and "G" in model], [os.path.join("assets", "models", path, model) for path in ["pretrained_v1", "pretrained_v2", "pretrained_custom"] for model in os.listdir(os.path.join("assets", "models", path)) if model.endswith(".pth") and ("D" in model or "G" in model)])
81
-
82
- separate_model = sorted([os.path.join("assets", "models", "uvr5", models) for models in os.listdir(os.path.join("assets", "models", "uvr5")) if models.endswith((".th", ".yaml", ".onnx"))])
83
- presets_file = sorted(list(f for f in os.listdir(os.path.join("assets", "presets")) if f.endswith(".json")))
84
- f0_file = sorted([os.path.abspath(os.path.join(root, f)) for root, _, files in os.walk(os.path.join("assets", "f0")) for f in files if f.endswith(".txt")])
85
-
86
- language, theme, edgetts, google_tts_voice, mdx_model, uvr_model, font = configs.get("language", "vi-VN"), configs.get("theme", "NoCrypt/miku"), configs.get("edge_tts", ["vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural"]), configs.get("google_tts_voice", ["vi", "en"]), configs.get("mdx_model", "MDXNET_Main"), (configs.get("demucs_model", "HD_MMI") + configs.get("mdx_model", "MDXNET_Main")), configs.get("font", "https://fonts.googleapis.com/css2?family=Courgette&display=swap")
87
-
88
- csv_path = os.path.join("assets", "spreadsheet.csv")
89
- logger.info(config.device)
90
-
91
- if "--allow_all_disk" in sys.argv:
92
- import win32api
93
-
94
- allow_disk = win32api.GetLogicalDriveStrings().split('\x00')[:-1]
95
- else: allow_disk = []
96
-
97
- if language == "vi-VN":
98
- import gradio.strings
99
- gradio.strings.en = {"RUNNING_LOCALLY": "* Chạy trên liên kết nội bộ: {}://{}:{}", "RUNNING_LOCALLY_SSR": "* Chạy trên liên kết nội bộ: {}://{}:{}, với SSR ⚡ (thử nghiệm, để tắt hãy dùng `ssr=False` trong `launch()`)", "SHARE_LINK_DISPLAY": "* Chạy trên liên kết công khai: {}", "COULD_NOT_GET_SHARE_LINK": "\nKhông thể tạo liên kết công khai. Vui lòng kiểm tra kết nối mạng của bạn hoặc trang trạng thái của chúng tôi: https://status.gradio.app.", "COULD_NOT_GET_SHARE_LINK_MISSING_FILE": "\nKhông thể tạo liên kết công khai. Thiếu tập tin: {}. \n\nVui lòng kiểm tra kết nối internet của bạn. Điều này có thể xảy ra nếu phần mềm chống vi-rút của bạn chặn việc tải xuống tệp này. Bạn có thể cài đặt thủ công bằng cách làm theo các bước sau: \n\n1. Tải xuống tệp này: {}\n2. Đổi tên tệp đã tải xuống thành: {}\n3. Di chuyển tệp đến vị trí này: {}", "COLAB_NO_LOCAL": "Không thể hiển thị giao diện nội bộ trên google colab, liên kết công khai đã được tạo.", "PUBLIC_SHARE_TRUE": "\nĐể tạo một liên kết công khai, hãy đặt `share=True` trong `launch()`.", "MODEL_PUBLICLY_AVAILABLE_URL": "Mô hình được cung cấp công khai tại: {} (có thể mất tới một phút để sử dụng được liên kết)", "GENERATING_PUBLIC_LINK": "Đang tạo liên kết công khai (có thể mất vài giây...):", "BETA_INVITE": "\nCảm ơn bạn đã là người dùng Gradio! Nếu bạn có thắc mắc hoặc phản hồi, vui lòng tham gia máy chủ Discord của chúng tôi và trò chuyện với chúng tôi: https://discord.gg/feTf9x3ZSB", "COLAB_DEBUG_TRUE": "Đã phát hiện thấy sổ tay Colab. Ô này sẽ chạy vô thời hạn để bạn có thể xem lỗi và nhật ký. " "Để tắt, hãy đặt debug=False trong launch().", "COLAB_DEBUG_FALSE": "Đã phát hiện thấy sổ tay Colab. Để hiển thị lỗi trong sổ ghi chép colab, hãy đặt debug=True trong launch()", "COLAB_WARNING": "Lưu ý: việc mở Chrome Inspector có thể làm hỏng bản demo trong sổ tay Colab.", "SHARE_LINK_MESSAGE": "\nLiên kết công khai sẽ hết hạn sau 72 giờ. 
Để nâng cấp GPU và lưu trữ vĩnh viễn miễn phí, hãy chạy `gradio deploy` từ terminal trong thư mục làm việc để triển khai lên huggingface (https://huggingface.co/spaces)", "INLINE_DISPLAY_BELOW": "Đang tải giao diện bên dưới...", "COULD_NOT_GET_SHARE_LINK_CHECKSUM": "\nKhông thể tạo liên kết công khai. Tổng kiểm tra không khớp cho tập tin: {}."}
100
-
101
- if os.path.exists(csv_path): cached_data = pd.read_csv(csv_path)
102
- else:
103
- cached_data = pd.read_csv(codecs.decode("uggcf://qbpf.tbbtyr.pbz/fcernqfurrgf/q/1gNHnDeRULtEfz1Yieaw14USUQjWJy0Oq9k0DrCrjApb/rkcbeg?sbezng=pfi&tvq=1977693859", "rot13"))
104
- cached_data.to_csv(csv_path, index=False)
105
-
106
- for _, row in cached_data.iterrows():
107
- filename = row['Filename']
108
- url = None
109
-
110
- for value in row.values:
111
- if isinstance(value, str) and "huggingface" in value:
112
- url = value
113
- break
114
-
115
- if url: models[filename] = url
116
-
117
-
118
-
119
- def gr_info(message):
120
- gr.Info(message, duration=2)
121
- logger.info(message)
122
-
123
- def gr_warning(message):
124
- gr.Warning(message, duration=2)
125
- logger.warning(message)
126
-
127
- def gr_error(message):
128
- gr.Error(message=message, duration=6)
129
- logger.error(message)
130
-
131
- def get_gpu_info():
132
- ngpu = torch.cuda.device_count()
133
- gpu_infos = [f"{i}: {torch.cuda.get_device_name(i)} ({int(torch.cuda.get_device_properties(i).total_memory / 1024 / 1024 / 1024 + 0.4)} GB)" for i in range(ngpu) if torch.cuda.is_available() or ngpu != 0]
134
- return "\n".join(gpu_infos) if len(gpu_infos) > 0 else translations["no_support_gpu"]
135
-
136
- def change_f0_choices():
137
- f0_file = sorted([os.path.abspath(os.path.join(root, f)) for root, _, files in os.walk(os.path.join("assets", "f0")) for f in files if f.endswith(".txt")])
138
- return {"value": f0_file[0] if len(f0_file) >= 1 else "", "choices": f0_file, "__type__": "update"}
139
-
140
- def change_audios_choices(input_audio):
141
- audios = sorted([os.path.abspath(os.path.join(root, f)) for root, _, files in os.walk("audios") for f in files if os.path.splitext(f)[1].lower() in (".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3")])
142
- return {"value": input_audio if input_audio != "" else (audios[0] if len(audios) >= 1 else ""), "choices": audios, "__type__": "update"}
143
-
144
- def change_separate_choices():
145
- return [{"choices": sorted([os.path.join("assets", "models", "uvr5", models) for models in os.listdir(os.path.join("assets", "models", "uvr5")) if model.endswith((".th", ".yaml", ".onnx"))]), "__type__": "update"}]
146
-
147
- def change_models_choices():
148
- model, index = sorted(list(model for model in os.listdir(os.path.join("assets", "weights")) if model.endswith((".pth", ".onnx")) and not model.startswith("G_") and not model.startswith("D_"))), sorted([os.path.join(root, name) for root, _, files in os.walk(os.path.join("assets", "logs"), topdown=False) for name in files if name.endswith(".index") and "trained" not in name])
149
- return [{"value": model[0] if len(model) >= 1 else "", "choices": model, "__type__": "update"}, {"value": index[0] if len(index) >= 1 else "", "choices": index, "__type__": "update"}]
150
-
151
- def change_allpretrained_choices():
152
- return [{"choices": sorted([os.path.join("assets", "models", path, model) for path in ["pretrained_v1", "pretrained_v2", "pretrained_custom"] for model in os.listdir(os.path.join("assets", "models", path)) if model.endswith(".pth") and ("D" in model or "G" in model)]), "__type__": "update"}]
153
-
154
- def change_pretrained_choices():
155
- return [{"choices": sorted([model for model in os.listdir(os.path.join("assets", "models", "pretrained_custom")) if model.endswith(".pth") and "D" in model]), "__type__": "update"}, {"choices": sorted([model for model in os.listdir(os.path.join("assets", "models", "pretrained_custom")) if model.endswith(".pth") and "G" in model]), "__type__": "update"}]
156
-
157
- def change_choices_del():
158
- return [{"choices": sorted(list(model for model in os.listdir(os.path.join("assets", "weights")) if model.endswith(".pth") and not model.startswith("G_") and not model.startswith("D_"))), "__type__": "update"}, {"choices": sorted([os.path.join("assets", "logs", f) for f in os.listdir(os.path.join("assets", "logs")) if "mute" not in f and os.path.isdir(os.path.join("assets", "logs", f))]), "__type__": "update"}]
159
-
160
- def change_preset_choices():
161
- return {"value": "", "choices": sorted(list(f for f in os.listdir(os.path.join("assets", "presets")) if f.endswith(".json"))), "__type__": "update"}
162
-
163
- def change_tts_voice_choices(google):
164
- return {"choices": google_tts_voice if google else edgetts, "value": google_tts_voice[0] if google else edgetts[0], "__type__": "update"}
165
-
166
- def change_backing_choices(backing, merge):
167
- if backing or merge: return {"value": False, "interactive": False, "__type__": "update"}
168
- elif not backing or not merge: return {"interactive": True, "__type__": "update"}
169
- else: gr_warning(translations["option_not_valid"])
170
-
171
- def change_download_choices(select):
172
- selects = [False]*10
173
-
174
- if select == translations["download_url"]: selects[0] = selects[1] = selects[2] = True
175
- elif select == translations["download_from_csv"]: selects[3] = selects[4] = True
176
- elif select == translations["search_models"]: selects[5] = selects[6] = True
177
- elif select == translations["upload"]: selects[9] = True
178
- else: gr_warning(translations["option_not_valid"])
179
-
180
- return [{"visible": selects[i], "__type__": "update"} for i in range(len(selects))]
181
-
182
- def change_download_pretrained_choices(select):
183
- selects = [False]*8
184
-
185
- if select == translations["download_url"]: selects[0] = selects[1] = selects[2] = True
186
- elif select == translations["list_model"]: selects[3] = selects[4] = selects[5] = True
187
- elif select == translations["upload"]: selects[6] = selects[7] = True
188
- else: gr_warning(translations["option_not_valid"])
189
-
190
- return [{"visible": selects[i], "__type__": "update"} for i in range(len(selects))]
191
-
192
- def get_index(model):
193
- model = os.path.basename(model).split("_")[0]
194
- return {"value": next((f for f in [os.path.join(root, name) for root, _, files in os.walk(os.path.join("assets", "logs"), topdown=False) for name in files if name.endswith(".index") and "trained" not in name] if model.split(".")[0] in f), ""), "__type__": "update"} if model else None
195
-
196
- def index_strength_show(index):
197
- return {"visible": index != "" and os.path.exists(index), "value": 0.5, "__type__": "update"}
198
-
199
- def hoplength_show(method, hybrid_method=None):
200
- show_hop_length_method = ["mangio-crepe-tiny", "mangio-crepe-small", "mangio-crepe-medium", "mangio-crepe-large", "mangio-crepe-full", "fcpe", "fcpe-legacy", "yin", "pyin"]
201
-
202
- if method in show_hop_length_method: visible = True
203
- elif method == "hybrid":
204
- methods_str = re.search("hybrid\[(.+)\]", hybrid_method)
205
- if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
206
-
207
- for i in methods:
208
- visible = i in show_hop_length_method
209
- if visible: break
210
- else: visible = False
211
-
212
- return {"visible": visible, "__type__": "update"}
213
-
214
- def visible(value):
215
- return {"visible": value, "__type__": "update"}
216
-
217
- def valueFalse_interactive(inp):
218
- return {"value": False, "interactive": inp, "__type__": "update"}
219
-
220
- def valueEmpty_visible1(inp1):
221
- return {"value": "", "visible": inp1, "__type__": "update"}
222
-
223
- def process_input(file_path):
224
- file_contents = ""
225
-
226
- if not file_path.endswith(".srt"):
227
- with open(file_path, "r", encoding="utf-8") as file:
228
- file_contents = file.read()
229
-
230
- gr_info(translations["upload_success"].format(name=translations["text"]))
231
- return file_contents
232
-
233
def fetch_pretrained_data():
    """Fetch the JSON catalogue of custom pretrained models (URL is rot13-obfuscated)."""
    catalogue_url = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/wfba/phfgbz_cergenvarq.wfba", "rot13")
    response = requests.get(catalogue_url)
    response.raise_for_status()
    return response.json()
238
-
239
def update_sample_rate_dropdown(model):
    """Refresh the sample-rate dropdown choices for the selected pretrained model."""
    data = fetch_pretrained_data()
    if model == translations["success"]:
        return None

    rates = list(data[model].keys())
    return {"choices": rates, "value": rates[0], "__type__": "update"}
242
-
243
def if_done(done, p):
    """Block until subprocess *p* exits, then flag completion through done[0]."""
    while p.poll() is None:
        sleep(0.5)
    done[0] = True
249
-
250
def restart_app():
    """Close the running Gradio app and relaunch it in a fresh process.

    Used after settings changes (language/theme/font) that need a full reload.
    """
    global app

    # Tell the user a restart is underway (~15 s notice).
    gr_info(translations["15s"])
    # Clear the console for a clean relaunch.
    os.system("cls" if platform.system() == "Windows" else "clear")

    app.close()
    # Re-exec the app entry point, forwarding the original CLI arguments.
    subprocess.run([python, os.path.join("main", "app", "app.py")] + sys.argv[1:])
258
-
259
def change_language(lang):
    """Persist the UI language in the JSON config file and restart the app.

    Args:
        lang: Language code stored under the "language" key.

    BUGFIX/consistency: the old code used ``json.load(open(...))`` and never
    closed the read handle; now uses a context manager like change_theme /
    change_font do.
    """
    with open(configs_json, "r") as f:
        configs = json.load(f)

    configs["language"] = lang

    with open(configs_json, "w") as f:
        json.dump(configs, f, indent=4)

    restart_app()
267
-
268
def change_theme(theme):
    """Persist the selected Gradio theme, then restart so it takes effect."""
    with open(configs_json, "r") as cfg_file:
        settings = json.load(cfg_file)

    settings["theme"] = theme
    with open(configs_json, "w") as cfg_file:
        json.dump(settings, cfg_file, indent=4)

    restart_app()
277
-
278
def change_font(font):
    """Persist the selected UI font, then restart so it takes effect."""
    with open(configs_json, "r") as cfg_file:
        settings = json.load(cfg_file)

    settings["font"] = font
    with open(configs_json, "w") as cfg_file:
        json.dump(settings, cfg_file, indent=4)

    restart_app()
287
-
288
def zip_file(name, pth, index):
    """Bundle a model's weight file (plus optional index) into assets/logs/<name>/<name>.zip."""
    pth_path = os.path.join("assets", "weights", pth)
    if not pth or not os.path.exists(pth_path) or not pth.endswith((".pth", ".onnx")):
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))

    gr_info(translations["start"].format(start=translations["zip"]))
    zip_file_path = os.path.join("assets", "logs", name, name + ".zip")

    import zipfile
    with zipfile.ZipFile(zip_file_path, 'w') as archive:
        archive.write(pth_path, os.path.basename(pth_path))
        if index:
            archive.write(index, os.path.basename(index))

    gr_info(translations["success"])
    return {"visible": True, "value": zip_file_path, "__type__": "update"}
302
-
303
def fetch_models_data(search):
    """Page through the remote voice-models API, collecting raw HTML table chunks.

    Args:
        search: Search term forwarded to the remote endpoint.

    Returns:
        List of HTML table fragment strings, one per non-empty result page.
    """
    all_table_data = []
    page = 1

    while 1:
        try:
            # Endpoint URL is rot13-obfuscated; POST one page of results at a time.
            response = requests.post(url=codecs.decode("uggcf://ibvpr-zbqryf.pbz/srgpu_qngn.cuc", "rot13"), data={"page": page, "search": search})

            if response.status_code == 200:
                table_data = response.json().get("table", "")
                # An empty table marks the last page.
                if not table_data.strip(): break
                all_table_data.append(table_data)
                page += 1
            else:
                # Non-200: log and stop paging, returning whatever was collected.
                logger.debug(f"{translations['code_error']} {response.status_code}")
                break
        except json.JSONDecodeError:
            logger.debug(translations["json_error"])
            break
        except requests.RequestException as e:
            logger.debug(translations["requests_error"].format(e=e))
            break
    return all_table_data
326
-
327
def search_models(name):
    """Search the remote model catalogue and populate the download dropdown.

    Args:
        name: Search query string.

    Returns:
        [dropdown_update, button_update] on success, or [None, None] when no
        results were found.

    BUGFIX: rows missing either anchor used to raise TypeError because
    ``url_tag["href"]`` was dereferenced before the None-check; the guard now
    runs first.
    """
    gr_info(translations["start"].format(start=translations["search"]))
    tables = fetch_models_data(name)

    if len(tables) == 0:
        gr_info(translations["not_found"].format(name=name))
        return [None]*2
    else:
        model_options.clear()

        from bs4 import BeautifulSoup

        for table in tables:
            for row in BeautifulSoup(table, "html.parser").select("tr"):
                name_tag, url_tag = row.find("a", {"class": "fs-5"}), row.find("a", {"class": "btn btn-sm fw-bold btn-light ms-0 p-1 ps-2 pe-2"})
                # Skip rows that lack either the model name or the download link.
                if not (name_tag and url_tag): continue

                url = url_tag["href"].replace("https://easyaivoice.com/run?url=", "")
                if "huggingface" in url:
                    # Sanitize the display name into a filesystem/key-safe identifier.
                    model_options[name_tag.text.replace(".onnx", "").replace(".pth", "").replace(".index", "").replace(".zip", "").replace(" ", "_").replace("(", "").replace(")", "").replace("[", "").replace("]", "").replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip()] = url

        gr_info(translations["found"].format(results=len(model_options)))
        return [{"value": "", "choices": model_options, "interactive": True, "visible": True, "__type__": "update"}, {"value": translations["downloads"], "visible": True, "__type__": "update"}]
348
-
349
def move_files_from_directory(src_dir, dest_weights, dest_logs, model_name):
    """Sort extracted model files: .index files into dest_logs/<model_name>,
    .pth/.onnx weights (except D_/G_ training checkpoints) into dest_weights."""
    for root, _, files in os.walk(src_dir):
        for file in files:
            file_path = os.path.join(root, file)

            if file.endswith(".index"):
                model_log_dir = os.path.join(dest_logs, model_name)
                os.makedirs(model_log_dir, exist_ok=True)

                # Sanitize the index filename before placing it.
                cleaned = file.replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip()
                target = os.path.join(model_log_dir, cleaned)
                if os.path.exists(target):
                    os.remove(target)
                shutil.move(file_path, target)
            elif file.endswith((".pth", ".onnx")) and not file.startswith(("D_", "G_")):
                # Rename the weight file to <model_name> with its original extension.
                ext = ".pth" if file.endswith(".pth") else ".onnx"
                target = os.path.join(dest_weights, model_name + ext)
                if os.path.exists(target):
                    os.remove(target)
                shutil.move(file_path, target)
371
-
372
def download_url(url):
    """Download the audio track of a media URL into ./audios as WAV via yt-dlp.

    Returns:
        [audio_path, audio_path, status_message] for the Gradio outputs, or a
        warning when no URL was provided.
    """
    import yt_dlp

    if not url: return gr_warning(translations["provide_url"])
    if not os.path.exists("audios"): os.makedirs("audios", exist_ok=True)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # Extract best audio and transcode to 192k WAV via ffmpeg.
        ydl_opts = {"format": "bestaudio/best", "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "quiet": True, "no_warnings": True, "noplaylist": True, "verbose": False}

        gr_info(translations["start"].format(start=translations["download_music"]))

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            # Build a filesystem-safe name from the title; the character class keeps
            # word chars plus CJK, Hangul and Cyrillic ranges, then spaces become dashes.
            audio_output = os.path.join("audios", re.sub(r'\s+', '-', re.sub(r'[^\w\s\u4e00-\u9fff\uac00-\ud7af\u0400-\u04FF\u1100-\u11FF]', '', ydl.extract_info(url, download=False).get('title', 'video')).strip()))
            if os.path.exists(audio_output): shutil.rmtree(audio_output, ignore_errors=True)

            ydl_opts['outtmpl'] = audio_output

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            # yt-dlp appends ".wav" after post-processing; remove any stale file first.
            audio_output = audio_output + ".wav"
            if os.path.exists(audio_output): os.remove(audio_output)

            ydl.download([url])

    gr_info(translations["success"])
    return [audio_output, audio_output, translations["success"]]
398
-
399
def download_model(url=None, model=None):
    """Download a voice model (pth/onnx/index/zip) from a URL into assets/.

    Supports direct file links (HuggingFace-style) plus Google Drive, Mega,
    MediaFire and Pixeldrain. Archives are unpacked and their contents sorted
    into assets/weights and assets/logs. The staging dir is always removed.

    Args:
        url: Source link.
        model: Name to save the model under (sanitized below).

    Returns:
        A status translation string, or a warning result when inputs are missing.
    """
    if not url: return gr_warning(translations["provide_url"])
    if not model: return gr_warning(translations["provide_name_is_save"])

    # Sanitize the model name and normalize HuggingFace blob links to resolve links.
    model = model.replace(".onnx", "").replace(".pth", "").replace(".index", "").replace(".zip", "").replace(" ", "_").replace("(", "").replace(")", "").replace("[", "").replace("]", "").replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip()
    url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()

    download_dir = os.path.join("download_model")
    weights_dir = os.path.join("assets", "weights")
    logs_dir = os.path.join("assets", "logs")

    if not os.path.exists(download_dir): os.makedirs(download_dir, exist_ok=True)
    if not os.path.exists(weights_dir): os.makedirs(weights_dir, exist_ok=True)
    if not os.path.exists(logs_dir): os.makedirs(logs_dir, exist_ok=True)

    try:
        gr_info(translations["start"].format(start=translations["download"]))

        # Direct file links: place the file straight into its destination.
        if url.endswith(".pth"): huggingface.HF_download_file(url, os.path.join(weights_dir, f"{model}.pth"))
        elif url.endswith(".onnx"): huggingface.HF_download_file(url, os.path.join(weights_dir, f"{model}.onnx"))
        elif url.endswith(".index"):
            model_log_dir = os.path.join(logs_dir, model)
            os.makedirs(model_log_dir, exist_ok=True)

            huggingface.HF_download_file(url, os.path.join(model_log_dir, f"{model}.index"))
        elif url.endswith(".zip"):
            output_path = huggingface.HF_download_file(url, os.path.join(download_dir, model + ".zip"))
            shutil.unpack_archive(output_path, download_dir)

            move_files_from_directory(download_dir, weights_dir, logs_dir, model)
        else:
            # Hosting services where the URL carries no file extension.
            if "drive.google.com" in url or "drive.usercontent.google.com" in url:
                file_id = None

                from main.tools import gdown

                # Extract the file id from the various Drive URL shapes.
                if "/file/d/" in url: file_id = url.split("/d/")[1].split("/")[0]
                elif "open?id=" in url: file_id = url.split("open?id=")[1].split("/")[0]
                elif "/download?id=" in url: file_id = url.split("/download?id=")[1].split("&")[0]

                if file_id:
                    file = gdown.gdown_download(id=file_id, output=download_dir)
                    if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)

                    move_files_from_directory(download_dir, weights_dir, logs_dir, model)
            elif "mega.nz" in url:
                from main.tools import meganz

                meganz.mega_download_url(url, download_dir)

                # Mega downloads keep their original name; pick the first staged file.
                file_download = next((f for f in os.listdir(download_dir)), None)
                if file_download.endswith(".zip"): shutil.unpack_archive(os.path.join(download_dir, file_download), download_dir)

                move_files_from_directory(download_dir, weights_dir, logs_dir, model)
            elif "mediafire.com" in url:
                from main.tools import mediafire

                file = mediafire.Mediafire_Download(url, download_dir)
                if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)

                move_files_from_directory(download_dir, weights_dir, logs_dir, model)
            elif "pixeldrain.com" in url:
                from main.tools import pixeldrain

                file = pixeldrain.pixeldrain(url, download_dir)
                if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)

                move_files_from_directory(download_dir, weights_dir, logs_dir, model)
            else:
                gr_warning(translations["not_support_url"])
                return translations["not_support_url"]

        gr_info(translations["success"])
        return translations["success"]
    except Exception as e:
        gr_error(message=translations["error_occurred"].format(e=e))
        logger.debug(e)
        return translations["error_occurred"].format(e=e)
    finally:
        # Always clean up the staging directory, even on failure.
        shutil.rmtree(download_dir, ignore_errors=True)
479
-
480
def save_drop_model(dropbox):
    """Import a model file dropped into the UI (zip/pth/onnx/index) into assets/.

    The file is staged in a temp folder, sorted into assets/weights or
    assets/logs, and the staging area is always cleaned up.

    Args:
        dropbox: Path of the dropped file.

    Returns:
        None (communicates via gr_info / gr_warning / gr_error).

    BUGFIX: the old guard was ``endswith(".pth") and endswith(".onnx") and
    endswith(".index")`` — a condition that can never be true, so unsupported
    files fell through to a nested else. The extension check is now performed
    up front.
    """
    weight_folder = os.path.join("assets", "weights")
    logs_folder = os.path.join("assets", "logs")
    save_model_temp = os.path.join("save_model_temp")

    for folder in (weight_folder, logs_folder, save_model_temp):
        os.makedirs(folder, exist_ok=True)

    shutil.move(dropbox, save_model_temp)

    try:
        file_name = os.path.basename(dropbox)

        if not file_name.endswith((".pth", ".onnx", ".index", ".zip")):
            # Not a recognizable model artifact.
            gr_warning(translations["unable_analyze_model"])
            return None

        if file_name.endswith(".zip"):
            shutil.unpack_archive(os.path.join(save_model_temp, file_name), save_model_temp)
            move_files_from_directory(save_model_temp, weight_folder, logs_folder, file_name.replace(".zip", ""))
        elif file_name.endswith((".pth", ".onnx")):
            output_file = os.path.join(weight_folder, file_name)
            if os.path.exists(output_file): os.remove(output_file)

            shutil.move(os.path.join(save_model_temp, file_name), output_file)
        elif file_name.endswith(".index"):
            def extract_name_model(filename):
                # Pull the leading alphabetic model name out of the index filename.
                match = re.search(r"([A-Za-z]+)(?=_v|\.|$)", filename)
                return match.group(1) if match else None

            model_logs = os.path.join(logs_folder, extract_name_model(file_name))
            if not os.path.exists(model_logs): os.makedirs(model_logs, exist_ok=True)
            shutil.move(os.path.join(save_model_temp, file_name), model_logs)

        gr_info(translations["upload_success"].format(name=translations["model"]))
        return None
    except Exception as e:
        gr_error(message=translations["error_occurred"].format(e=e))
        logger.debug(e)
        return None
    finally:
        # Always remove the staging directory.
        shutil.rmtree(save_model_temp, ignore_errors=True)
524
-
525
def download_pretrained_model(choices, model, sample_rate):
    """Download pretrained weights, either from the curated list or from direct URLs.

    Args:
        choices: UI mode — translations["list_model"] or translations["download_url"].
        model: Model key (list mode) or the first pretrain URL (URL mode).
        sample_rate: Sample-rate key (list mode) or the second pretrain URL (URL mode).

    Returns:
        A status translation string, or a warning result when a URL is missing.
    """
    pretraineds_custom_path = os.path.join("assets", "models", "pretrained_custom")
    if choices == translations["list_model"]:
        # Resolve the file path from the remote catalogue for this model/sample rate.
        paths = fetch_pretrained_data()[model][sample_rate]

        if not os.path.exists(pretraineds_custom_path): os.makedirs(pretraineds_custom_path, exist_ok=True)
        # Base repo URL is rot13-obfuscated.
        url = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cergenvarq_phfgbz/", "rot13") + paths

        gr_info(translations["download_pretrain"])
        file = huggingface.HF_download_file(url.replace("/blob/", "/resolve/").replace("?download=true", "").strip(), os.path.join(pretraineds_custom_path, paths))

        if file.endswith(".zip"):
            shutil.unpack_archive(file, pretraineds_custom_path)
            os.remove(file)

        gr_info(translations["success"])
        return translations["success"]
    elif choices == translations["download_url"]:
        # NOTE(review): in URL mode `model`/`sample_rate` carry the two pretrain
        # URLs (labelled "D" and "G" in the warnings) — confirm against the caller.
        if not model: return gr_warning(translations["provide_pretrain"].format(dg="D"))
        if not sample_rate: return gr_warning(translations["provide_pretrain"].format(dg="G"))

        gr_info(translations["download_pretrain"])

        huggingface.HF_download_file(model.replace("/blob/", "/resolve/").replace("?download=true", "").strip(), pretraineds_custom_path)
        huggingface.HF_download_file(sample_rate.replace("/blob/", "/resolve/").replace("?download=true", "").strip(), pretraineds_custom_path)

        gr_info(translations["success"])
        return translations["success"]
553
-
554
def fushion_model_pth(name, pth_1, pth_2, ratio):
    """Blend two RVC .pth checkpoints into a single model by weighted averaging.

    Args:
        name: Output filename (".pth" appended when missing).
        pth_1 / pth_2: Paths to the source checkpoints.
        ratio: Weight of pth_1; pth_2 receives (1 - ratio).

    Returns:
        [status_message, output_path_or_None]
    """
    if not name.endswith(".pth"): name = name + ".pth"

    if not pth_1 or not os.path.exists(pth_1) or not pth_1.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"] + " 1"))
        return [translations["provide_file"].format(filename=translations["model"] + " 1"), None]

    if not pth_2 or not os.path.exists(pth_2) or not pth_2.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"] + " 2"))
        return [translations["provide_file"].format(filename=translations["model"] + " 2"), None]

    from collections import OrderedDict

    def extract(ckpt):
        # Copy all weights except the enc_q (training-only) tensors.
        a = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {}

        for key in a.keys():
            if "enc_q" in key: continue

            opt["weight"][key] = a[key]

        return opt

    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
        ckpt1 = torch.load(pth_1, map_location="cpu")
        ckpt2 = torch.load(pth_2, map_location="cpu")

        # Both checkpoints must share a sample rate to be blendable.
        if ckpt1["sr"] != ckpt2["sr"]:
            gr_warning(translations["sr_not_same"])
            return [translations["sr_not_same"], None]

        # Metadata is carried over from checkpoint 1.
        cfg = ckpt1["config"]
        cfg_f0 = ckpt1["f0"]
        cfg_version = ckpt1["version"]
        cfg_sr = ckpt1["sr"]

        vocoder = ckpt1.get("vocoder", "Default")

        # Normalize both to a flat weight dict, whichever layout they use.
        ckpt1 = extract(ckpt1) if "model" in ckpt1 else ckpt1["weight"]
        ckpt2 = extract(ckpt2) if "model" in ckpt2 else ckpt2["weight"]

        # Mismatched key sets mean different architectures — refuse to blend.
        if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
            gr_warning(translations["architectures_not_same"])
            return [translations["architectures_not_same"], None]

        gr_info(translations["start"].format(start=translations["fushion_model"]))

        opt = OrderedDict()
        opt["weight"] = {}

        for key in ckpt1.keys():
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                # Speaker-embedding tables may differ in speaker count: blend the overlap.
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (ratio * (ckpt1[key][:min_shape0].float()) + (1 - ratio) * (ckpt2[key][:min_shape0].float())).half()
            else: opt["weight"][key] = (ratio * (ckpt1[key].float()) + (1 - ratio) * (ckpt2[key].float())).half()

        opt["config"] = cfg
        opt["sr"] = cfg_sr
        opt["f0"] = cfg_f0
        opt["version"] = cfg_version
        opt["infos"] = translations["model_fushion_info"].format(name=name, pth_1=pth_1, pth_2=pth_2, ratio=ratio)
        opt["vocoder"] = vocoder

        output_model = os.path.join("assets", "weights")
        if not os.path.exists(output_model): os.makedirs(output_model, exist_ok=True)

        torch.save(opt, os.path.join(output_model, name))

        gr_info(translations["success"])
        return [translations["success"], os.path.join(output_model, name)]
    except Exception as e:
        gr_error(message=translations["error_occurred"].format(e=e))
        logger.debug(e)
        return [e, None]
630
-
631
def fushion_model(name, path_1, path_2, ratio):
    """Validate inputs and dispatch model fusion for two .pth checkpoints.

    Args:
        name: Output model name (".pth" enforced downstream).
        path_1 / path_2: Paths of the two checkpoints to blend.
        ratio: Blend ratio applied to path_1.

    Returns:
        [status_message_or_None, output_path_or_None]

    BUGFIX: a missing path used to raise AttributeError on ``None.endswith``;
    both paths are now checked for truthiness first.
    """
    if not name:
        gr_warning(translations["provide_name_is_save"])
        return [translations["provide_name_is_save"], None]

    if path_1 and path_2 and path_1.endswith(".pth") and path_2.endswith(".pth"):
        return fushion_model_pth(name.replace(".onnx", ".pth"), path_1, path_2, ratio)
    else:
        gr_warning(translations["format_not_valid"])
        return [None, None]
640
-
641
def onnx_export(model_path):
    """Convert a .pth voice model to ONNX next to the original file.

    Args:
        model_path: Path to the .pth checkpoint (extension appended if absent).

    Returns:
        [onnx_path, status_message] on success, [None, error] otherwise.

    BUGFIX: the old code evaluated ``model_path + ".pth"`` without assigning the
    result, so a path supplied without an extension was never normalized.
    """
    from main.library.algorithm.onnx_export import onnx_exporter

    if not model_path.endswith(".pth"): model_path = model_path + ".pth"
    if not model_path or not os.path.exists(model_path) or not model_path.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"]))
        return [None, translations["provide_file"].format(filename=translations["model"])]

    try:
        gr_info(translations["start_onnx_export"])
        output = onnx_exporter(model_path, model_path.replace(".pth", ".onnx"), is_half=config.is_half, device=config.device)

        gr_info(translations["success"])
        return [output, translations["success"]]
    except Exception as e:
        # Surface the exception as the status message; no ONNX file produced.
        return [None, e]
657
-
658
def model_info(path):
    """Read and format metadata (epochs, steps, sr, author, ...) from a model file.

    Supports .pth checkpoints (torch pickle) and .onnx models (JSON blob stored
    in metadata_props under the "model_info" key).

    Returns:
        A formatted translation string, or a warning result for invalid paths.
    """
    if not path or not os.path.exists(path) or os.path.isdir(path) or not path.endswith((".pth", ".onnx")): return gr_warning(translations["provide_file"].format(filename=translations["model"]))

    def prettify_date(date_str):
        # Reformat an ISO timestamp for display; None when no timestamp was stored.
        if date_str == translations["not_found_create_time"]: return None

        try:
            return datetime.datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f").strftime("%Y-%m-%d %H:%M:%S")
        except ValueError as e:
            logger.debug(e)
            return translations["format_not_valid"]

    # NOTE(review): torch.load unpickles arbitrary objects — only open trusted models.
    if path.endswith(".pth"): model_data = torch.load(path, map_location=torch.device("cpu"))
    else:
        import onnx

        model = onnx.load(path)
        model_data = None

        # ONNX models carry their metadata as a JSON string in metadata_props.
        for prop in model.metadata_props:
            if prop.key == "model_info":
                model_data = json.loads(prop.value)
                break

    gr_info(translations["read_info"])

    epochs = model_data.get("epoch", None)
    if epochs is None:
        # Older checkpoints keep epoch info inside the free-form "info" field.
        epochs = model_data.get("info", None)
        try:
            # NOTE(review): this branch looks ineffective — after the assignment
            # above, `epochs is None` can no longer hold inside the `if`, so the
            # fallback text is never assigned; confirm intended behavior.
            epoch = epochs.replace("epoch", "").replace("e", "").isdigit()
            if epoch and epochs is None: epochs = translations["not_found"].format(name=translations["epoch"])
        except:
            # Bare except deliberately swallows parse failures on non-string "info".
            pass

    steps = model_data.get("step", translations["not_found"].format(name=translations["step"]))
    sr = model_data.get("sr", translations["not_found"].format(name=translations["sr"]))
    f0 = model_data.get("f0", translations["not_found"].format(name=translations["f0"]))
    version = model_data.get("version", translations["not_found"].format(name=translations["version"]))
    creation_date = model_data.get("creation_date", translations["not_found_create_time"])
    model_hash = model_data.get("model_hash", translations["not_found"].format(name="model_hash"))
    pitch_guidance = translations["trained_f0"] if f0 else translations["not_f0"]
    creation_date_str = prettify_date(creation_date) if creation_date else translations["not_found_create_time"]
    model_name = model_data.get("model_name", translations["unregistered"])
    model_author = model_data.get("author", translations["not_author"])
    vocoder = model_data.get("vocoder", "Default")

    gr_info(translations["success"])
    return translations["model_info"].format(model_name=model_name, model_author=model_author, epochs=epochs, steps=steps, version=version, sr=sr, pitch_guidance=pitch_guidance, model_hash=model_hash, creation_date_str=creation_date_str, vocoder=vocoder)
707
-
708
def audio_effects(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out, audio_combination, audio_combination_input):
    """Apply the selected audio effect chain by running audio_effects.py as a subprocess.

    Validates the input/output paths, forwards every effect parameter as a CLI
    flag, then returns the produced output path (or None on invalid input).
    """
    if not input_path or not os.path.exists(input_path) or os.path.isdir(input_path):
        gr_warning(translations["input_not_valid"])
        return None

    if not output_path:
        gr_warning(translations["output_not_valid"])
        return None

    # A directory target gets a default filename in the chosen export format.
    if os.path.isdir(output_path): output_path = os.path.join(output_path, f"audio_effects.{export_format}")
    output_dir = os.path.dirname(output_path) or output_path

    if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
    if os.path.exists(output_path): os.remove(output_path)

    gr_info(translations["start"].format(start=translations["apply_effect"]))
    # Every UI parameter is stringified and handed to the inference script;
    # note "--pitchshift" is derived from pitch_shift != 0 rather than a flag.
    subprocess.run([python, "main/inference/audio_effects.py", "--input_path", input_path, "--output_path", output_path, "--resample", str(resample), "--resample_sr", str(resample_sr), "--chorus_depth", str(chorus_depth), "--chorus_rate", str(chorus_rate), "--chorus_mix", str(chorus_mix), "--chorus_delay", str(chorus_delay), "--chorus_feedback", str(chorus_feedback), "--drive_db", str(distortion_drive), "--reverb_room_size", str(reverb_room_size), "--reverb_damping", str(reverb_damping), "--reverb_wet_level", str(reverb_wet_level), "--reverb_dry_level", str(reverb_dry_level), "--reverb_width", str(reverb_width), "--reverb_freeze_mode", str(reverb_freeze_mode), "--pitch_shift", str(pitch_shift), "--delay_seconds", str(delay_seconds), "--delay_feedback", str(delay_feedback), "--delay_mix", str(delay_mix), "--compressor_threshold", str(compressor_threshold), "--compressor_ratio", str(compressor_ratio), "--compressor_attack_ms", str(compressor_attack_ms), "--compressor_release_ms", str(compressor_release_ms), "--limiter_threshold", str(limiter_threshold), "--limiter_release", str(limiter_release), "--gain_db", str(gain_db), "--bitcrush_bit_depth", str(bitcrush_bit_depth), "--clipping_threshold", str(clipping_threshold), "--phaser_rate_hz", str(phaser_rate_hz), "--phaser_depth", str(phaser_depth), "--phaser_centre_frequency_hz", str(phaser_centre_frequency_hz), "--phaser_feedback", str(phaser_feedback), "--phaser_mix", str(phaser_mix), "--bass_boost_db", str(bass_boost_db), "--bass_boost_frequency", str(bass_boost_frequency), "--treble_boost_db", str(treble_boost_db), "--treble_boost_frequency", str(treble_boost_frequency), "--fade_in_duration", str(fade_in_duration), "--fade_out_duration", str(fade_out_duration), "--export_format", export_format, "--chorus", str(chorus), "--distortion", str(distortion), "--reverb", str(reverb), "--pitchshift", str(pitch_shift != 0), "--delay", str(delay), "--compressor", str(compressor), "--limiter", str(limiter), "--gain", str(gain), "--bitcrush", str(bitcrush), "--clipping", str(clipping), "--phaser", str(phaser), "--treble_bass_boost", str(treble_bass_boost), "--fade_in_out", str(fade_in_out), "--audio_combination", str(audio_combination), "--audio_combination_input", audio_combination_input])

    gr_info(translations["success"])
    # NOTE(review): str.replace swaps every "wav" substring in the whole path,
    # not just the extension — a directory named "wav" would be renamed too.
    return output_path.replace("wav", export_format)
728
-
729
def synthesize_tts(prompt, voice, speed, output, pitch, google):
    """Render *prompt* to speech at *output* using edge-tts or Google Translate TTS.

    Args:
        prompt: Text to speak.
        voice: Voice name (edge-tts) or language code (Google).
        speed: Rate adjustment (percent string for edge-tts; stretch factor for Google).
        output: Output audio file path.
        pitch: Pitch adjustment (Hz for edge-tts; semitone steps for Google post-processing).
        google: True to use the Google Translate endpoint instead of edge-tts.
    """
    if not google:
        from edge_tts import Communicate

        # edge-tts expects signed strings such as "+10%" / "-5Hz".
        asyncio.run(Communicate(text=prompt, voice=voice, rate=f"+{speed}%" if speed >= 0 else f"{speed}%", pitch=f"+{pitch}Hz" if pitch >= 0 else f"{pitch}Hz").save(output))
    else:
        # Google Translate TTS endpoint (rot13-obfuscated); browser UA avoids blocking.
        response = requests.get(codecs.decode("uggcf://genafyngr.tbbtyr.pbz/genafyngr_ggf", "rot13"), params={"ie": "UTF-8", "q": prompt, "tl": voice, "ttsspeed": speed, "client": "tw-ob"}, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36"})

        if response.status_code == 200:
            with open(output, "wb") as f:
                f.write(response.content)

            # This endpoint has no native pitch/speed controls: post-process with librosa.
            if pitch != 0 or speed != 0:
                y, sr = librosa.load(output, sr=None)

                if pitch != 0: y = librosa.effects.pitch_shift(y, sr=sr, n_steps=pitch)
                if speed != 0: y = librosa.effects.time_stretch(y, rate=speed)

                sf.write(file=output, data=y, samplerate=sr, format=os.path.splitext(os.path.basename(output))[-1].lower().replace('.', ''))
        else: gr_error(f"{response.status_code}, {response.text}")
749
-
750
def time_stretch(y, sr, target_duration):
    """Stretch or trim audio *y* so it lasts exactly *target_duration* seconds."""
    rate = (len(y) / sr) / target_duration
    # Only resynthesize when the length actually differs.
    if rate != 1.0:
        y = librosa.effects.time_stretch(y=y.astype(np.float32), rate=rate)

    n_target = int(round(target_duration * sr))
    if len(y) >= n_target:
        return y[:n_target]
    return np.pad(y, (0, n_target - len(y)))
756
-
757
def pysrttime_to_seconds(t):
    """Convert a pysrt SubRipTime-like object into a float number of seconds."""
    total_minutes = t.hours * 60 + t.minutes
    return total_minutes * 60 + t.seconds + t.milliseconds / 1000
759
-
760
def srt_tts(srt_file, out_file, voice, rate = 0, sr = 24000, google = False):
    """Synthesize each subtitle of an .srt file and assemble a timed audio track.

    Each segment is rendered with synthesize_tts, resampled to *sr*, stretched
    to its subtitle duration, and mixed into place on a silent master buffer.

    Raises:
        ValueError: when the .srt file contains no subtitles.
    """
    import pysrt
    import tempfile

    subs = pysrt.open(srt_file)
    if not subs: raise ValueError(translations["srt"])

    # Silent master buffer spanning up to the last subtitle's end time.
    final_audio = np.zeros(int(round(pysrttime_to_seconds(subs[-1].end) * sr)), dtype=np.float32)

    with tempfile.TemporaryDirectory() as tempdir:
        for idx, seg in enumerate(subs):
            wav_path = os.path.join(tempdir, f"seg_{idx}.wav")
            # Flatten multi-line subtitle text into a single utterance.
            synthesize_tts(" ".join(seg.text.splitlines()), voice, 0, wav_path, rate, google)

            audio, file_sr = sf.read(wav_path, dtype=np.float32)
            # Linear-interpolation resample when the TTS sample rate differs.
            if file_sr != sr: audio = np.interp(np.linspace(0, len(audio) - 1, int(len(audio) * sr / file_sr)), np.arange(len(audio)), audio)
            adjusted = time_stretch(audio, sr, pysrttime_to_seconds(seg.duration))

            start_sample = int(round(pysrttime_to_seconds(seg.start) * sr))
            end_sample = start_sample + adjusted.shape[0]

            # Clip segments that would run past the end of the timeline.
            if end_sample > final_audio.shape[0]:
                adjusted = adjusted[: final_audio.shape[0] - start_sample]
                end_sample = final_audio.shape[0]

            # Mix (add) the segment into place; overlapping subtitles sum.
            final_audio[start_sample:end_sample] += adjusted

    sf.write(out_file, final_audio, sr)
788
-
789
def TTS(prompt, voice, speed, output, pitch, google, srt_input):
    """Entry point for text-to-speech: plain text or an .srt subtitle file."""
    srt_input = srt_input or ""

    if not prompt and not srt_input.endswith(".srt"):
        gr_warning(translations["enter_the_text"])
        return None
    if not voice:
        gr_warning(translations["choose_voice"])
        return None
    if not output:
        gr_warning(translations["output_not_valid"])
        return None

    if os.path.isdir(output):
        output = os.path.join(output, "tts.wav")
    gr_info(translations["convert"].format(name=translations["text"]))

    target_dir = os.path.dirname(output) or output
    os.makedirs(target_dir, exist_ok=True)

    if srt_input.endswith(".srt"):
        # Subtitle mode: timing comes from the .srt, so speed/pitch are fixed.
        srt_tts(srt_input, output, voice, 0, 24000, google)
    else:
        synthesize_tts(prompt, voice, speed, output, pitch, google)

    gr_info(translations["success"])
    return output
815
-
816
def separator_music(input, output_audio, format, shifts, segments_size, overlap, clean_audio, clean_strength, denoise, separator_model, kara_model, backing, reverb, backing_reverb, hop_length, batch_size, sample_rate):
    """Run vocal/instrument separation in a subprocess and return the stem paths.

    Returns:
        [original_vocals, instruments, main_vocals, backing_vocals] file paths
        (entries may be None depending on the backing/reverb options), or
        [None]*4 on invalid input/output.
    """
    # The output directory is derived from the chosen output file path.
    output = os.path.dirname(output_audio) or output_audio

    if not input or not os.path.exists(input) or os.path.isdir(input):
        gr_warning(translations["input_not_valid"])
        return [None]*4

    if not os.path.exists(output):
        gr_warning(translations["output_not_valid"])
        return [None]*4

    # NOTE(review): unreachable — the branch above already returned when the
    # output directory does not exist; confirm whether creating it was intended.
    if not os.path.exists(output): os.makedirs(output)
    gr_info(translations["start"].format(start=translations["separator_music"]))

    subprocess.run([python, "main/inference/separator_music.py", "--input_path", input, "--output_path", output, "--format", format, "--shifts", str(shifts), "--segments_size", str(segments_size), "--overlap", str(overlap), "--mdx_hop_length", str(hop_length), "--mdx_batch_size", str(batch_size), "--clean_audio", str(clean_audio), "--clean_strength", str(clean_strength), "--kara_model", kara_model, "--backing", str(backing), "--mdx_denoise", str(denoise), "--reverb", str(reverb), "--backing_reverb", str(backing_reverb), "--model_name", separator_model, "--sample_rate", str(sample_rate)])
    gr_info(translations["success"])

    # Stems land in <output>/<input filename without extension>/.
    filename, _ = os.path.splitext(os.path.basename(input))
    output = os.path.join(output, filename)

    # Pick stem filenames depending on which post-processing steps ran
    # ("No_Reverb" variants when de-reverb was enabled, backing stems only when requested).
    return [os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if reverb else os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}"), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if reverb else os.path.join(output, f"Main_Vocals.{format}") if backing else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else os.path.join(output, f"Backing_Vocals.{format}") if backing else None)] if os.path.isfile(input) else [None]*4
837
-
838
def convert(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, embedders_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file):
    """Launch main/inference/convert.py with the given voice-conversion settings.

    Thin blocking wrapper: every value is stringified into a CLI flag; an empty
    string is passed when no index path was provided.
    """
    subprocess.run([python, "main/inference/convert.py", "--pitch", str(pitch), "--filter_radius", str(filter_radius), "--index_rate", str(index_rate), "--volume_envelope", str(volume_envelope), "--protect", str(protect), "--hop_length", str(hop_length), "--f0_method", f0_method, "--input_path", input_path, "--output_path", output_path, "--pth_path", pth_path, "--index_path", index_path if index_path else "", "--f0_autotune", str(f0_autotune), "--clean_audio", str(clean_audio), "--clean_strength", str(clean_strength), "--export_format", export_format, "--embedder_model", embedder_model, "--resample_sr", str(resample_sr), "--split_audio", str(split_audio), "--f0_autotune_strength", str(f0_autotune_strength), "--checkpointing", str(checkpointing), "--f0_onnx", str(onnx_f0_mode), "--embedders_mode", embedders_mode, "--formant_shifting", str(formant_shifting), "--formant_qfrency", str(formant_qfrency), "--formant_timbre", str(formant_timbre), "--f0_file", f0_file])
840
-
841
def convert_audio(clean, autotune, use_audio, use_original, convert_backing, not_merge_backing, merge_instrument, pitch, clean_strength, model, index, index_rate, input, output, format, method, hybrid_method, hop_length, embedders, custom_embedders, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, input_audio_name, checkpointing, onnx_f0_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file, embedders_mode):
    """Convert vocals with an RVC model, either from a previously separated song
    folder (`use_audio=True`, files under audios/<input_audio_name>) or from an
    explicit input/output path pair.

    Returns a 6-element list: [converted vocals, converted backing, vocals+backing
    mix, original-voice output, vocals+instruments mix, gradio visibility update];
    unused slots are None. On any validation failure a gr_warning is shown and the
    all-None list (with the visibility update in slot 5) is returned.
    """
    model_path = os.path.join("assets", "weights", model)

    # Template failure/result value: 5 payload slots plus a gradio "update" dict.
    return_none = [None]*6
    return_none[5] = {"visible": True, "__type__": "update"}

    # The folder-based options only make sense when working on a separated song.
    if not use_audio:
        if merge_instrument or not_merge_backing or convert_backing or use_original:
            gr_warning(translations["turn_on_use_audio"])
            return return_none

    # Converting the original (unsplit) vocal excludes the backing-track options.
    if use_original:
        if convert_backing:
            gr_warning(translations["turn_off_convert_backup"])
            return return_none
        elif not_merge_backing:
            gr_warning(translations["turn_off_merge_backup"])
            return return_none

    if not model or not os.path.exists(model_path) or os.path.isdir(model_path) or not model.endswith((".pth", ".onnx")):
        gr_warning(translations["provide_file"].format(filename=translations["model"]))
        return return_none

    # Resolve the "hybrid"/"custom" indirections chosen in the UI.
    f0method, embedder_model = (method if method != "hybrid" else hybrid_method), (embedders if embedders != "custom" else custom_embedders)

    if use_audio:
        output_audio = os.path.join("audios", input_audio_name)

        from main.library.utils import pydub_convert, pydub_load

        def get_audio_file(label):
            # First file in the song folder whose name contains `label`,
            # or the translated "not found" sentinel string.
            matching_files = [f for f in os.listdir(output_audio) if label in f]

            if not matching_files: return translations["notfound"]
            return os.path.join(output_audio, matching_files[0])

        output_path = os.path.join(output_audio, f"Convert_Vocals.{format}")
        output_backing = os.path.join(output_audio, f"Convert_Backing.{format}")
        output_merge_backup = os.path.join(output_audio, f"Vocals+Backing.{format}")
        output_merge_instrument = os.path.join(output_audio, f"Vocals+Instruments.{format}")

        # NOTE(review): this guard looks inverted — makedirs only runs when the
        # folder already exists; verify whether `if not os.path.exists(...)` was meant.
        if os.path.exists(output_audio): os.makedirs(output_audio, exist_ok=True)
        if os.path.exists(output_path): os.remove(output_path)

        if use_original:
            # Prefer the de-reverbed original vocal, fall back to the raw one.
            original_vocal = get_audio_file('Original_Vocals_No_Reverb.')

            if original_vocal == translations["notfound"]: original_vocal = get_audio_file('Original_Vocals.')

            if original_vocal == translations["notfound"]:
                gr_warning(translations["not_found_original_vocal"])
                return return_none

            input_path = original_vocal
        else:
            # Main/backing vocal pair, each preferring the de-reverbed version.
            main_vocal = get_audio_file('Main_Vocals_No_Reverb.')
            backing_vocal = get_audio_file('Backing_Vocals_No_Reverb.')

            if main_vocal == translations["notfound"]: main_vocal = get_audio_file('Main_Vocals.')
            if not not_merge_backing and backing_vocal == translations["notfound"]: backing_vocal = get_audio_file('Backing_Vocals.')

            if main_vocal == translations["notfound"]:
                gr_warning(translations["not_found_main_vocal"])
                return return_none

            if not not_merge_backing and backing_vocal == translations["notfound"]:
                gr_warning(translations["not_found_backing_vocal"])
                return return_none

            input_path = main_vocal
            backing_path = backing_vocal

        gr_info(translations["convert_vocal"])

        convert(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0method, input_path, output_path, model_path, index, autotune, clean, clean_strength, format, embedder_model, resample_sr, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, embedders_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file)

        gr_info(translations["convert_success"])

        if convert_backing:
            if os.path.exists(output_backing): os.remove(output_backing)

            gr_info(translations["convert_backup"])

            # Run the same conversion again on the backing vocal track.
            convert(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0method, backing_path, output_backing, model_path, index, autotune, clean, clean_strength, format, embedder_model, resample_sr, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, embedders_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file)

            gr_info(translations["convert_backup_success"])

        try:
            # Overlay the (possibly converted) backing on the converted main vocal.
            if not not_merge_backing and not use_original:
                backing_source = output_backing if convert_backing else backing_vocal

                if os.path.exists(output_merge_backup): os.remove(output_merge_backup)

                gr_info(translations["merge_backup"])

                pydub_convert(pydub_load(output_path)).overlay(pydub_convert(pydub_load(backing_source))).export(output_merge_backup, format=format)

                gr_info(translations["merge_success"])

            if merge_instrument:
                # Mix the vocal result (with backing when it was merged) over the instrumental.
                vocals = output_merge_backup if not not_merge_backing and not use_original else output_path

                if os.path.exists(output_merge_instrument): os.remove(output_merge_instrument)

                gr_info(translations["merge_instruments_process"])

                instruments = get_audio_file('Instruments.')

                if instruments == translations["notfound"]:
                    gr_warning(translations["not_found_instruments"])
                    output_merge_instrument = None
                else: pydub_convert(pydub_load(instruments)).overlay(pydub_convert(pydub_load(vocals))).export(output_merge_instrument, format=format)

                gr_info(translations["merge_success"])
        except:
            # NOTE(review): bare except maps any merge failure to the all-None
            # result without logging; consider recording the exception.
            return return_none

        return [(None if use_original else output_path), output_backing, (None if not_merge_backing and use_original else output_merge_backup), (output_path if use_original else None), (output_merge_instrument if merge_instrument else None), {"visible": True, "__type__": "update"}]
    else:
        if not input or not os.path.exists(input) or os.path.isdir(input):
            gr_warning(translations["input_not_valid"])
            return return_none

        if not output:
            gr_warning(translations["output_not_valid"])
            return return_none

        # NOTE(review): replaces every "wav" substring anywhere in the path,
        # not just the file extension.
        output = output.replace("wav", format)

        if os.path.isdir(input):
            # NOTE(review): dead code — directory inputs were already rejected by
            # the validity check above, so this batch branch can never run.
            gr_info(translations["is_folder"])

            if not [f for f in os.listdir(input) if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]:
                gr_warning(translations["not_found_in_folder"])
                return return_none

            gr_info(translations["batch_convert"])

            output_dir = os.path.dirname(output) or output
            convert(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0method, input, output_dir, model_path, index, autotune, clean, clean_strength, format, embedder_model, resample_sr, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, embedders_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file)

            gr_info(translations["batch_convert_success"])

            return return_none
        else:
            output_dir = os.path.dirname(output) or output

            if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
            if os.path.exists(output): os.remove(output)

            gr_info(translations["convert_vocal"])

            convert(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0method, input, output, model_path, index, autotune, clean, clean_strength, format, embedder_model, resample_sr, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, embedders_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file)

            gr_info(translations["convert_success"])

            return_none[0] = output
            return return_none
def convert_selection(clean, autotune, use_audio, use_original, convert_backing, not_merge_backing, merge_instrument, pitch, clean_strength, model, index, index_rate, input, output, format, method, hybrid_method, hop_length, embedders, custom_embedders, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file, embedders_mode):
    """UI entry point that either runs convert_audio directly or, when working
    from separated songs, first discovers the candidate folders under audios/.

    Returns a 7-element list: a gradio dropdown update for the folder choice,
    five audio outputs (as produced by convert_audio), and a visibility update
    for the main controls.
    """
    if use_audio:
        gr_info(translations["search_separate"])

        # Each sub-folder of audios/ is one previously separated song.
        choice = [f for f in os.listdir("audios") if os.path.isdir(os.path.join("audios", f))]

        gr_info(translations["found_choice"].format(choice=len(choice)))

        if len(choice) == 0:
            gr_warning(translations["separator==0"])

            return [{"choices": [], "value": "", "interactive": False, "visible": False, "__type__": "update"}, None, None, None, None, None, {"visible": True, "__type__": "update"}]
        elif len(choice) == 1:
            # Only one candidate: convert it right away, no selection needed.
            convert_output = convert_audio(clean, autotune, use_audio, use_original, convert_backing, not_merge_backing, merge_instrument, pitch, clean_strength, model, index, index_rate, None, None, format, method, hybrid_method, hop_length, embedders, custom_embedders, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, choice[0], checkpointing, onnx_f0_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file, embedders_mode)

            return [{"choices": [], "value": "", "interactive": False, "visible": False, "__type__": "update"}, convert_output[0], convert_output[1], convert_output[2], convert_output[3], convert_output[4], {"visible": True, "__type__": "update"}]
        # Several candidates: show the dropdown and let the user pick one.
        else: return [{"choices": choice, "value": "", "interactive": True, "visible": True, "__type__": "update"}, None, None, None, None, None, {"visible": False, "__type__": "update"}]
    else:
        # Plain file conversion; only the first output slot is meaningful.
        main_convert = convert_audio(clean, autotune, use_audio, use_original, convert_backing, not_merge_backing, merge_instrument, pitch, clean_strength, model, index, index_rate, input, output, format, method, hybrid_method, hop_length, embedders, custom_embedders, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, None, checkpointing, onnx_f0_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file, embedders_mode)

        return [{"choices": [], "value": "", "interactive": False, "visible": False, "__type__": "update"}, main_convert[0], None, None, None, None, {"visible": True, "__type__": "update"}]
def convert_with_whisper(num_spk, model_size, cleaner, clean_strength, autotune, f0_autotune_strength, checkpointing, model_1, model_2, model_index_1, model_index_2, pitch_1, pitch_2, index_strength_1, index_strength_2, export_format, input_audio, output_audio, onnx_f0_mode, method, hybrid_method, hop_length, embed_mode, embedders, custom_embedders, resample_sr, filter_radius, volume_envelope, protect, formant_shifting, formant_qfrency_1, formant_timbre_1, formant_qfrency_2, formant_timbre_2):
    """Diarize a multi-speaker recording with Whisper + speaker embeddings, convert
    each speaker's segments with a different RVC model, and reassemble the result.

    Returns the merged output path, or None on validation failure or error.
    """
    from pydub import AudioSegment
    from sklearn.cluster import AgglomerativeClustering

    from main.library.speaker_diarization.audio import Audio
    from main.library.speaker_diarization.segment import Segment
    from main.library.speaker_diarization.whisper import load_model
    from main.library.utils import check_spk_diarization, pydub_convert, pydub_load
    from main.library.speaker_diarization.embedding import SpeechBrainPretrainedSpeakerEmbedding

    check_spk_diarization(model_size)
    model_pth_1, model_pth_2 = os.path.join("assets", "weights", model_1), os.path.join("assets", "weights", model_2)

    # At least one of the two voice models must be valid.
    if (not model_1 or not os.path.exists(model_pth_1) or os.path.isdir(model_pth_1) or not model_pth_1.endswith((".pth", ".onnx"))) and (not model_2 or not os.path.exists(model_pth_2) or os.path.isdir(model_pth_2) or not model_pth_2.endswith((".pth", ".onnx"))):
        gr_warning(translations["provide_file"].format(filename=translations["model"]))
        return None

    # A missing model falls back to the other one (same voice for both speakers).
    if not model_1: model_pth_1 = model_pth_2
    if not model_2: model_pth_2 = model_pth_1

    if not input_audio or not os.path.exists(input_audio) or os.path.isdir(input_audio):
        gr_warning(translations["input_not_valid"])
        return None

    if not output_audio:
        gr_warning(translations["output_not_valid"])
        return None

    if os.path.exists(output_audio): os.remove(output_audio)
    gr_info(translations["start_whisper"])

    try:
        audio = Audio()

        embedding_model = SpeechBrainPretrainedSpeakerEmbedding(device=config.device)
        # Whisper transcription segments with word-level timestamps.
        segments = load_model(model_size, device=config.device).transcribe(input_audio, fp16=configs.get("fp16", False), word_timestamps=True)["segments"]

        y, sr = librosa.load(input_audio, sr=None)
        duration = len(y) / sr

        def segment_embedding(segment):
            # Speaker embedding for one transcript segment; stereo is downmixed to mono.
            waveform, _ = audio.crop(input_audio, Segment(segment["start"], min(duration, segment["end"])))
            return embedding_model(waveform.mean(dim=0, keepdim=True)[None] if waveform.shape[0] == 2 else waveform[None])

        def time(secs):
            # Seconds -> H:MM:SS (shadows the builtin `time` name inside this scope).
            return datetime.timedelta(seconds=round(secs))

        def merge_audio(files_list, time_stamps, original_file_path, output_path, format):
            # Reassemble converted segment files at their original millisecond
            # positions, padding gaps (and the tail) with silence.
            def extract_number(filename):
                match = re.search(r'_(\d+)', filename)
                return int(match.group(1)) if match else 0

            total_duration = len(pydub_load(original_file_path))
            combined = AudioSegment.empty()
            current_position = 0

            for file, (start_i, end_i) in zip(sorted(files_list, key=extract_number), time_stamps):
                if start_i > current_position: combined += AudioSegment.silent(duration=start_i - current_position)

                combined += pydub_load(file)
                current_position = end_i

            if current_position < total_duration: combined += AudioSegment.silent(duration=total_duration - current_position)
            combined.export(output_path, format=format)

            return output_path

        # 192 matches the dimensionality of the SpeechBrain speaker embedding output
        # expected here — TODO confirm against the embedding model.
        embeddings = np.zeros(shape=(len(segments), 192))
        for i, segment in enumerate(segments):
            embeddings[i] = segment_embedding(segment)

        # Cluster segment embeddings into num_spk speakers.
        labels = AgglomerativeClustering(num_spk).fit(np.nan_to_num(embeddings)).labels_
        for i in range(len(segments)):
            segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

        # Merge consecutive segments spoken by the same speaker into one span.
        merged_segments, current_text = [], []
        current_speaker, current_start = None, None

        for i, segment in enumerate(segments):
            speaker = segment["speaker"]
            start_time = segment["start"]
            text = segment["text"][1:]

            if speaker == current_speaker:
                current_text.append(text)
                end_time = segment["end"]
            else:
                if current_speaker is not None: merged_segments.append({"speaker": current_speaker, "start": current_start, "end": end_time, "text": " ".join(current_text)})

                current_speaker = speaker
                current_start = start_time
                current_text = [text]
                end_time = segment["end"]

        # Flush the final open span.
        if current_speaker is not None: merged_segments.append({"speaker": current_speaker, "start": current_start, "end": end_time, "text": " ".join(current_text)})

        gr_info(translations["whisper_done"])

        # Human-readable transcript with speaker tags and timestamps, for the log.
        x = ""
        for segment in merged_segments:
            x += f"\n{segment['speaker']} {str(time(segment['start']))} - {str(time(segment['end']))}\n"
            x += segment["text"] + "\n"

        logger.info(x)

        gr_info(translations["process_audio"])

        audio = pydub_convert(pydub_load(input_audio))
        output_folder = "audios_temp"

        if os.path.exists(output_folder): shutil.rmtree(output_folder, ignore_errors=True)
        for f in [output_folder, os.path.join(output_folder, "1"), os.path.join(output_folder, "2")]:
            os.makedirs(f, exist_ok=True)

        # Slice the source audio into per-span files, alternating between the two
        # model folders. NOTE(review): the split is by span parity, which only maps
        # to speakers because merged spans alternate speakers — this assumes
        # num_spk == 2; verify for other speaker counts.
        time_stamps, processed_segments = [], []
        for i, segment in enumerate(merged_segments):
            start_ms = int(segment["start"] * 1000)
            end_ms = int(segment["end"] * 1000)

            index = i + 1

            segment_filename = os.path.join(output_folder, "1" if i % 2 == 1 else "2", f"segment_{index}.wav")
            audio[start_ms:end_ms].export(segment_filename, format="wav")

            # Expected output name produced by the conversion script — presumably
            # convert.py appends "_output"; TODO confirm.
            processed_segments.append(os.path.join(output_folder, "1" if i % 2 == 1 else "2", f"segment_{index}_output.wav"))
            time_stamps.append((start_ms, end_ms))

        # Resolve the "hybrid"/"custom" indirections chosen in the UI.
        f0method, embedder_model = (method if method != "hybrid" else hybrid_method), (embedders if embedders != "custom" else custom_embedders)

        gr_info(translations["process_done_start_convert"])

        # Batch-convert each folder with its own model and per-speaker settings.
        convert(pitch_1, filter_radius, index_strength_1, volume_envelope, protect, hop_length, f0method, os.path.join(output_folder, "1"), output_folder, model_pth_1, model_index_1, autotune, cleaner, clean_strength, "wav", embedder_model, resample_sr, False, f0_autotune_strength, checkpointing, onnx_f0_mode, embed_mode, formant_shifting, formant_qfrency_1, formant_timbre_1, "")
        convert(pitch_2, filter_radius, index_strength_2, volume_envelope, protect, hop_length, f0method, os.path.join(output_folder, "2"), output_folder, model_pth_2, model_index_2, autotune, cleaner, clean_strength, "wav", embedder_model, resample_sr, False, f0_autotune_strength, checkpointing, onnx_f0_mode, embed_mode, formant_shifting, formant_qfrency_2, formant_timbre_2, "")

        gr_info(translations["convert_success"])
        return merge_audio(processed_segments, time_stamps, input_audio, output_audio.replace("wav", export_format), export_format)
    except Exception as e:
        gr_error(translations["error_occurred"].format(e=e))
        import traceback
        logger.debug(traceback.format_exc())
        return None
    finally:
        # Always drop the temporary segment folder, success or not.
        if os.path.exists("audios_temp"): shutil.rmtree("audios_temp", ignore_errors=True)
def convert_tts(clean, autotune, pitch, clean_strength, model, index, index_rate, input, output, format, method, hybrid_method, hop_length, embedders, custom_embedders, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file, embedders_mode):
    """Convert a TTS-generated audio file with an RVC model.

    `input` may be a file or a folder (the first audio file whose name contains
    "tts" is used). Returns the converted output path, or None after a
    gr_warning on any validation failure.
    """
    model_path = os.path.join("assets", "weights", model)

    # BUG FIX: test `model` instead of `model_path` — the joined path is always
    # truthy so the original `not model_path` check could never fire (now matches
    # the sibling check in convert_audio).
    if not model or not os.path.exists(model_path) or os.path.isdir(model_path) or not model.endswith((".pth", ".onnx")):
        gr_warning(translations["provide_file"].format(filename=translations["model"]))
        return None

    if not input or not os.path.exists(input):
        gr_warning(translations["input_not_valid"])
        return None

    if os.path.isdir(input):
        # Pick the first TTS output file inside the folder.
        input_audio = [f for f in os.listdir(input) if "tts" in f and f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]

        if not input_audio:
            gr_warning(translations["not_found_in_folder"])
            return None

        input = os.path.join(input, input_audio[0])

    if not output:
        gr_warning(translations["output_not_valid"])
        return None

    # NOTE(review): replaces every "wav" substring in the path, not only the extension.
    output = output.replace("wav", format)
    if os.path.isdir(output): output = os.path.join(output, f"tts.{format}")

    # BUG FIX: a bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; fall back to the current directory in that case.
    output_dir = os.path.dirname(output) or "."
    if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)

    if os.path.exists(output): os.remove(output)

    # Resolve the "hybrid"/"custom" indirections chosen in the UI.
    f0method = method if method != "hybrid" else hybrid_method
    embedder_model = embedders if embedders != "custom" else custom_embedders

    gr_info(translations["convert_vocal"])

    convert(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0method, input, output, model_path, index, autotune, clean, clean_strength, format, embedder_model, resample_sr, split_audio, f0_autotune_strength, checkpointing, onnx_f0_mode, embedders_mode, formant_shifting, formant_qfrency, formant_timbre, f0_file)

    gr_info(translations["convert_success"])
    return output
def log_read(log_file, done):
    """Generator that tails `log_file`, yielding its filtered contents about once
    a second until `done[0]` becomes truthy, then yields the final contents.

    Lines containing "DEBUG" and blank lines are dropped from every yield.
    `done` is a single-element mutable flag list set by the watcher thread.
    """
    # Truncate (or create) the log file so only fresh output is streamed.
    with open(log_file, "w", encoding="utf-8"):
        pass

    def filtered():
        # Re-read the whole file each tick, dropping DEBUG lines and blanks.
        with open(log_file, "r", encoding="utf-8") as f:
            return "".join(line for line in f.readlines() if "DEBUG" not in line and line.strip() != "")

    while True:
        yield filtered()

        # BUG FIX: check the done flag *before* sleeping so the final contents
        # are yielded immediately instead of after a needless one-second delay.
        if done[0]: break
        sleep(1)

    # Final read after the watcher reports completion.
    yield filtered()
def create_dataset(input_audio, output_dataset, clean_dataset, clean_strength, separator_reverb, kim_vocals_version, overlap, segments_size, denoise_mdx, skip, skip_start, skip_end, hop_length, batch_size, sample_rate):
    """Launch the dataset-creation script as a subprocess and stream its log file."""
    version = 2 if kim_vocals_version != "Version-1" else 1

    gr_info(translations["start"].format(start=translations["create"]))

    command = f'{python} main/inference/create_dataset.py --input_audio "{input_audio}" --output_dataset "{output_dataset}" --clean_dataset {clean_dataset} --clean_strength {clean_strength} --separator_reverb {separator_reverb} --kim_vocal_version {version} --overlap {overlap} --segments_size {segments_size} --mdx_hop_length {hop_length} --mdx_batch_size {batch_size} --denoise_mdx {denoise_mdx} --skip {skip} --skip_start_audios "{skip_start}" --skip_end_audios "{skip_end}" --sample_rate {sample_rate}'
    process = subprocess.Popen(command, shell=True)

    done = [False]
    threading.Thread(target=if_done, args=(done, process)).start()

    yield from log_read(os.path.join("assets", "logs", "create_dataset.log"), done)
def preprocess(model_name, sample_rate, cpu_core, cut_preprocess, process_effects, path, clean_dataset, clean_strength):
    """Launch dataset preprocessing for `model_name` as a subprocess and stream its log."""
    dataset = os.path.join(path)
    sr = int(float(sample_rate.rstrip("k")) * 1000)

    if not model_name:
        return gr_warning(translations["provide_name"])

    audio_exts = ("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3")
    has_audio = any(f.lower().endswith(audio_exts) for f in os.listdir(dataset) if os.path.isfile(os.path.join(dataset, f)))
    if not has_audio:
        return gr_warning(translations["not_found_data"])

    # Start from a clean model folder; preprocessing regenerates everything.
    model_dir = os.path.join("assets", "logs", model_name)
    if os.path.exists(model_dir):
        shutil.rmtree(model_dir, ignore_errors=True)

    command = f'{python} main/inference/preprocess.py --model_name "{model_name}" --dataset_path "{dataset}" --sample_rate {sr} --cpu_cores {cpu_core} --cut_preprocess {cut_preprocess} --process_effects {process_effects} --clean_dataset {clean_dataset} --clean_strength {clean_strength}'
    process = subprocess.Popen(command, shell=True)

    done = [False]
    threading.Thread(target=if_done, args=(done, process)).start()
    os.makedirs(model_dir, exist_ok=True)

    yield from log_read(os.path.join(model_dir, "preprocess.log"), done)
def extract(model_name, version, method, pitch_guidance, hop_length, cpu_cores, gpu, sample_rate, embedders, custom_embedders, onnx_f0_mode, embedders_mode):
    """Launch feature/F0 extraction for `model_name` as a subprocess and stream its log."""
    embedder_model = custom_embedders if embedders == "custom" else embedders
    sr = int(float(sample_rate.rstrip("k")) * 1000)

    if not model_name:
        return gr_warning(translations["provide_name"])

    model_dir = os.path.join("assets", "logs", model_name)

    def has_files(subfolder):
        # True when the preprocessing output folder contains at least one file.
        folder = os.path.join(model_dir, subfolder)
        return any(os.path.isfile(os.path.join(folder, f)) for f in os.listdir(folder))

    if not (has_files("sliced_audios") and has_files("sliced_audios_16k")):
        return gr_warning(translations["not_found_data_preprocess"])

    command = f'{python} main/inference/extract.py --model_name "{model_name}" --rvc_version {version} --f0_method {method} --pitch_guidance {pitch_guidance} --hop_length {hop_length} --cpu_cores {cpu_cores} --gpu {gpu} --sample_rate {sr} --embedder_model {embedder_model} --f0_onnx {onnx_f0_mode} --embedders_mode {embedders_mode}'
    process = subprocess.Popen(command, shell=True)

    done = [False]
    threading.Thread(target=if_done, args=(done, process)).start()
    os.makedirs(model_dir, exist_ok=True)

    yield from log_read(os.path.join(model_dir, "extract.log"), done)
def create_index(model_name, rvc_version, index_algorithm):
    """Launch index creation for a trained model as a subprocess and stream its log."""
    if not model_name:
        return gr_warning(translations["provide_name"])

    model_dir = os.path.join("assets", "logs", model_name)
    extracted_dir = os.path.join(model_dir, f"{rvc_version}_extracted")

    if not any(os.path.isfile(os.path.join(extracted_dir, f)) for f in os.listdir(extracted_dir)):
        return gr_warning(translations["not_found_data_extract"])

    command = f'{python} main/inference/create_index.py --model_name "{model_name}" --rvc_version {rvc_version} --index_algorithm {index_algorithm}'
    process = subprocess.Popen(command, shell=True)

    done = [False]
    threading.Thread(target=if_done, args=(done, process)).start()
    os.makedirs(model_dir, exist_ok=True)

    yield from log_read(os.path.join(model_dir, "create_index.log"), done)
def training(model_name, rvc_version, save_every_epoch, save_only_latest, save_every_weights, total_epoch, sample_rate, batch_size, gpu, pitch_guidance, not_pretrain, custom_pretrained, pretrain_g, pretrain_d, detector, threshold, clean_up, cache, model_author, vocoder, checkpointing, deterministic, benchmark):
    """Launch model training as a subprocess, resolving (and downloading when
    missing) the pretrained G/D checkpoints first, and stream the training log.

    Yields the filtered contents of assets/logs/<model_name>/train.log, trimmed
    to the last 100 lines.
    """
    sr = int(float(sample_rate.rstrip("k")) * 1000)
    if not model_name: return gr_warning(translations["provide_name"])

    model_dir = os.path.join("assets", "logs", model_name)
    # Remove any stale PID file from a previous run before starting a new one.
    if os.path.exists(os.path.join(model_dir, "train_pid.txt")): os.remove(os.path.join(model_dir, "train_pid.txt"))

    if not any(os.path.isfile(os.path.join(model_dir, f"{rvc_version}_extracted", f)) for f in os.listdir(os.path.join(model_dir, f"{rvc_version}_extracted"))): return gr_warning(translations["not_found_data_extract"])

    if not not_pretrain:
        if not custom_pretrained:
            # Default pretrained filenames keyed by (pitch guidance, sample rate).
            pretrained_selector = {True: {32000: ("f0G32k.pth", "f0D32k.pth"), 40000: ("f0G40k.pth", "f0D40k.pth"), 48000: ("f0G48k.pth", "f0D48k.pth")}, False: {32000: ("G32k.pth", "D32k.pth"), 40000: ("G40k.pth", "D40k.pth"), 48000: ("G48k.pth", "D48k.pth")}}

            pg, pd = pretrained_selector[pitch_guidance][sr]
        else:
            if not pretrain_g: return gr_warning(translations["provide_pretrained"].format(dg="G"))
            if not pretrain_d: return gr_warning(translations["provide_pretrained"].format(dg="D"))

            pg, pd = pretrain_g, pretrain_d

        # Non-default vocoders use a "<vocoder>_" filename prefix.
        pretrained_G, pretrained_D = (os.path.join("assets", "models", f"pretrained_{rvc_version}", f"{vocoder}_{pg}" if vocoder != 'Default' else pg), os.path.join("assets", "models", f"pretrained_{rvc_version}", f"{vocoder}_{pd}" if vocoder != 'Default' else pd)) if not custom_pretrained else (os.path.join("assets", "models", "pretrained_custom", pg), os.path.join("assets", "models", "pretrained_custom", pd))
        # rot13-obfuscated base URL of the pretrained checkpoint repository.
        download_version = codecs.decode(f"uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cergenvarq_i{'2' if rvc_version == 'v2' else '1'}/", "rot13")

        if not custom_pretrained:
            try:
                if not os.path.exists(pretrained_G):
                    gr_info(translations["download_pretrained"].format(dg="G", rvc_version=rvc_version))
                    huggingface.HF_download_file("".join([download_version, vocoder, "_", pg]) if vocoder != 'Default' else (download_version + pg), os.path.join("assets", "models", f"pretrained_{rvc_version}", f"{vocoder}_{pg}" if vocoder != 'Default' else pg))

                if not os.path.exists(pretrained_D):
                    gr_info(translations["download_pretrained"].format(dg="D", rvc_version=rvc_version))
                    huggingface.HF_download_file("".join([download_version, vocoder, "_", pd]) if vocoder != 'Default' else (download_version + pd), os.path.join("assets", "models", f"pretrained_{rvc_version}", f"{vocoder}_{pd}" if vocoder != 'Default' else pd))
            except Exception:
                # Download failed: fall back to training without pretrained weights
                # (was a bare `except:`; narrowed so Ctrl-C still interrupts).
                gr_warning(translations["not_use_pretrain_error_download"])
                pretrained_G, pretrained_D = None, None
        else:
            if not os.path.exists(pretrained_G): return gr_warning(translations["not_found_pretrain"].format(dg="G"))
            if not os.path.exists(pretrained_D): return gr_warning(translations["not_found_pretrain"].format(dg="D"))
    else: gr_warning(translations["not_use_pretrain"])

    gr_info(translations["start"].format(start=translations["training"]))

    p = subprocess.Popen(f'{python} main/inference/train.py --model_name "{model_name}" --rvc_version {rvc_version} --save_every_epoch {save_every_epoch} --save_only_latest {save_only_latest} --save_every_weights {save_every_weights} --total_epoch {total_epoch} --sample_rate {sr} --batch_size {batch_size} --gpu {gpu} --pitch_guidance {pitch_guidance} --overtraining_detector {detector} --overtraining_threshold {threshold} --cleanup {clean_up} --cache_data_in_gpu {cache} --g_pretrained_path "{pretrained_G}" --d_pretrained_path "{pretrained_D}" --model_author "{model_author}" --vocoder "{vocoder}" --checkpointing {checkpointing} --deterministic {deterministic} --benchmark {benchmark}', shell=True)
    done = [False]

    # Record the trainer PID so stop_pid() can terminate it later.
    with open(os.path.join(model_dir, "train_pid.txt"), "w") as pid_file:
        pid_file.write(str(p.pid))

    threading.Thread(target=if_done, args=(done, p)).start()

    for log in log_read(os.path.join(model_dir, "train.log"), done):
        # BUG FIX: the original counted lines but then sliced the last 100
        # *characters* (log[-100:]); keep the last 100 lines instead.
        lines = log.split("\n")
        if len(lines) > 100: log = "\n".join(lines[-100:])
        yield log
def stop_pid(pid_file, model_name=None, train=False):
    """Best-effort kill of the processes recorded in a PID file and, for a
    training run, of the worker PIDs stored in the model's config.json.

    Shows a gr_warning when the PID file is missing, gr_info on completion;
    any failure (dead process, malformed file) is silently ignored.
    """
    try:
        pid_file_path = os.path.join("assets", f"{pid_file}.txt") if model_name is None else os.path.join("assets", "logs", model_name, f"{pid_file}.txt")

        if not os.path.exists(pid_file_path):
            return gr_warning(translations["not_found_pid"])

        with open(pid_file_path, "r") as f:
            pids = [int(line) for line in f.readlines()]

        for pid in pids:
            os.kill(pid, 9)

        if os.path.exists(pid_file_path):
            os.remove(pid_file_path)

        # BUG FIX: the config.json path was built unconditionally, which raised a
        # (silently swallowed) TypeError when model_name is None and skipped the
        # final "end_pid" notice; build it only when it is actually needed.
        if train and model_name is not None:
            config_path = os.path.join("assets", "logs", model_name, "config.json")

            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    pid_data = json.load(f)
                    pids = pid_data.get("process_pids", [])

                # Rewrite the config without the PID list before killing the workers.
                with open(config_path, "w") as f:
                    pid_data.pop("process_pids", None)
                    json.dump(pid_data, f, indent=4)

                for pid in pids:
                    os.kill(pid, 9)

        gr_info(translations["end_pid"])
    except Exception:
        # Best-effort cleanup: stale PID files or already-dead processes are
        # ignored (narrowed from a bare `except:` so Ctrl-C is not swallowed).
        pass
def load_presets(presets, cleaner, autotune, pitch, clean_strength, index_strength, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, formant_shifting, formant_qfrency, formant_timbre):
    """Read assets/presets/<presets> and return every setting, falling back to
    the currently selected UI value for any key the preset does not define."""
    if not presets:
        return gr_warning(translations["provide_file_settings"])

    with open(os.path.join("assets", "presets", presets)) as f:
        data = json.load(f)

    gr_info(translations["load_presets"].format(presets=presets))

    # Order must match the UI outputs wired to this callback.
    fallbacks = (
        ("cleaner", cleaner), ("autotune", autotune), ("pitch", pitch),
        ("clean_strength", clean_strength), ("index_strength", index_strength),
        ("resample_sr", resample_sr), ("filter_radius", filter_radius),
        ("volume_envelope", volume_envelope), ("protect", protect),
        ("split_audio", split_audio), ("f0_autotune_strength", f0_autotune_strength),
        ("formant_shifting", formant_shifting), ("formant_qfrency", formant_qfrency),
        ("formant_timbre", formant_timbre),
    )
    return tuple(data.get(key, current) for key, current in fallbacks)
- def save_presets(name, cleaner, autotune, pitch, clean_strength, index_strength, resample_sr, filter_radius, volume_envelope, protect, split_audio, f0_autotune_strength, cleaner_chbox, autotune_chbox, pitch_chbox, index_strength_chbox, resample_sr_chbox, filter_radius_chbox, volume_envelope_chbox, protect_chbox, split_audio_chbox, formant_shifting_chbox, formant_shifting, formant_qfrency, formant_timbre):
1386
- if not name: return gr_warning(translations["provide_filename_settings"])
1387
- if not any([cleaner_chbox, autotune_chbox, pitch_chbox, index_strength_chbox, resample_sr_chbox, filter_radius_chbox, volume_envelope_chbox, protect_chbox, split_audio_chbox, formant_shifting_chbox]): return gr_warning(translations["choose1"])
1388
-
1389
- settings = {}
1390
-
1391
- for checkbox, data in [(cleaner_chbox, {"cleaner": cleaner, "clean_strength": clean_strength}), (autotune_chbox, {"autotune": autotune, "f0_autotune_strength": f0_autotune_strength}), (pitch_chbox, {"pitch": pitch}), (index_strength_chbox, {"index_strength": index_strength}), (resample_sr_chbox, {"resample_sr": resample_sr}), (filter_radius_chbox, {"filter_radius": filter_radius}), (volume_envelope_chbox, {"volume_envelope": volume_envelope}), (protect_chbox, {"protect": protect}), (split_audio_chbox, {"split_audio": split_audio}), (formant_shifting_chbox, {"formant_shifting": formant_shifting, "formant_qfrency": formant_qfrency, "formant_timbre": formant_timbre})]:
1392
- if checkbox: settings.update(data)
1393
-
1394
- with open(os.path.join("assets", "presets", name + ".json"), "w") as f:
1395
- json.dump(settings, f, indent=4)
1396
-
1397
- gr_info(translations["export_settings"])
1398
- return change_preset_choices()
1399
-
1400
- def report_bug(error_info, provide):
1401
- report_path = os.path.join("assets", "logs", "report_bugs.log")
1402
- if os.path.exists(report_path): os.remove(report_path)
1403
-
1404
- report_url = codecs.decode(requests.get(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/jroubbx.gkg", "rot13")).text, "rot13")
1405
- if not error_info: error_info = "Không Có"
1406
-
1407
- gr_info(translations["thank"])
1408
-
1409
- if provide:
1410
- try:
1411
- for log in [os.path.join(root, name) for root, _, files in os.walk(os.path.join("assets", "logs"), topdown=False) for name in files if name.endswith(".log")]:
1412
- with open(log, "r", encoding="utf-8") as r:
1413
- with open(report_path, "a", encoding="utf-8") as w:
1414
- w.write(str(r.read()))
1415
- w.write("\n")
1416
- except Exception as e:
1417
- gr_error(translations["error_read_log"])
1418
- logger.debug(e)
1419
-
1420
- try:
1421
- with open(report_path, "r", encoding="utf-8") as f:
1422
- content = f.read()
1423
-
1424
- requests.post(report_url, json={"embeds": [{"title": "Báo Cáo Lỗi", "description": f"Mô tả lỗi: {error_info}", "color": 15158332, "author": {"name": "Vietnamese_RVC", "icon_url": codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/vpb.cat", "rot13"), "url": codecs.decode("uggcf://tvguho.pbz/CunzUhlauNau16/Ivrganzrfr-EIP/gerr/znva","rot13")}, "thumbnail": {"url": codecs.decode("uggcf://p.grabe.pbz/7dADJbv-36fNNNNq/grabe.tvs", "rot13")}, "fields": [{"name": "Số Lượng Gỡ Lỗi", "value": content.count("DEBUG")}, {"name": "Số Lượng Thông Tin", "value": content.count("INFO")}, {"name": "Số Lượng Cảnh Báo", "value": content.count("WARNING")}, {"name": "Số Lượng Lỗi", "value": content.count("ERROR")}], "footer": {"text": f"Tên Máy: {platform.uname().node} - Hệ Điều Hành: {platform.system()}-{platform.version()}\nThời Gian Báo Cáo Lỗi: {datetime.datetime.now()}."}}]})
1425
-
1426
- with open(report_path, "rb") as f:
1427
- requests.post(report_url, files={"file": f})
1428
- except Exception as e:
1429
- gr_error(translations["error_send"])
1430
- logger.debug(e)
1431
- finally:
1432
- if os.path.exists(report_path): os.remove(report_path)
1433
- else: requests.post(report_url, json={"embeds": [{"title": "Báo Cáo Lỗi", "description": error_info}]})
1434
-
1435
- def f0_extract(audio, f0_method, f0_onnx):
1436
- if not audio or not os.path.exists(audio) or os.path.isdir(audio):
1437
- gr_warning(translations["input_not_valid"])
1438
- return [None]*2
1439
-
1440
- from matplotlib import pyplot as plt
1441
- from main.library.utils import check_predictors
1442
- from main.inference.extract import FeatureInput
1443
-
1444
- check_predictors(f0_method, f0_onnx)
1445
-
1446
- f0_path = os.path.join("assets", "f0", os.path.splitext(os.path.basename(audio))[0])
1447
- image_path = os.path.join(f0_path, "f0.png")
1448
- txt_path = os.path.join(f0_path, "f0.txt")
1449
-
1450
- gr_info(translations["start_extract"])
1451
-
1452
- if not os.path.exists(f0_path): os.makedirs(f0_path, exist_ok=True)
1453
-
1454
- y, sr = librosa.load(audio, sr=None)
1455
-
1456
- feats = FeatureInput(sample_rate=sr, is_half=config.is_half, device=config.device)
1457
- feats.f0_max = 1600.0
1458
-
1459
- F_temp = np.array(feats.compute_f0(y.flatten(), f0_method, 160, f0_onnx), dtype=np.float32)
1460
- F_temp[F_temp == 0] = np.nan
1461
-
1462
- f0 = 1200 * np.log2(F_temp / librosa.midi_to_hz(0))
1463
-
1464
- plt.figure(figsize=(10, 4))
1465
- plt.plot(f0)
1466
- plt.title(f0_method)
1467
- plt.xlabel(translations["time_frames"])
1468
- plt.ylabel(translations["Frequency"])
1469
- plt.savefig(image_path)
1470
- plt.close()
1471
-
1472
- with open(txt_path, "w") as f:
1473
- for i, f0_value in enumerate(f0):
1474
- f.write(f"{i * sr / 160},{f0_value}\n")
1475
-
1476
- gr_info(translations["extract_done"])
1477
-
1478
- return [txt_path, image_path]
1479
-
1480
- def pitch_guidance_lock(vocoders):
1481
- return {"value": True, "interactive": vocoders == "Default", "__type__": "update"}
1482
-
1483
- def vocoders_lock(pitch, vocoders):
1484
- return {"value": vocoders if pitch else "Default", "interactive": pitch, "__type__": "update"}
1485
-
1486
- def run_audioldm2(input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute):
1487
- if not input_path or not os.path.exists(input_path) or os.path.isdir(input_path):
1488
- gr_warning(translations["input_not_valid"])
1489
- return None
1490
-
1491
- if not output_path:
1492
- gr_warning(translations["output_not_valid"])
1493
- return None
1494
-
1495
- output_path = output_path.replace("wav", export_format)
1496
-
1497
- if os.path.exists(output_path): os.remove(output_path)
1498
-
1499
- gr_info(translations["start_edit"].format(input_path=input_path))
1500
- subprocess.run([python, "main/inference/audioldm2.py", "--input_path", input_path, "--output_path", output_path, "--export_format", str(export_format), "--sample_rate", str(sample_rate), "--audioldm_model", audioldm_model, "--source_prompt", source_prompt, "--target_prompt", target_prompt, "--steps", str(steps), "--cfg_scale_src", str(cfg_scale_src), "--cfg_scale_tar", str(cfg_scale_tar), "--t_start", str(t_start), "--save_compute", str(save_compute)])
1501
-
1502
- gr_info(translations["success"])
1503
- return output_path
1504
-
1505
- def change_fp(fp):
1506
- fp16 = fp == "fp16"
1507
-
1508
- if fp16 and config.device == "cpu":
1509
- gr_warning(translations["fp16_not_support"])
1510
- return "fp32"
1511
- else:
1512
- gr_info(translations["start_update_precision"])
1513
-
1514
- configs = json.load(open(configs_json, "r"))
1515
- configs["fp16"] = config.is_half = fp16
1516
-
1517
- with open(configs_json, "w") as f:
1518
- json.dump(configs, f, indent=4)
1519
-
1520
- gr_info(translations["success"])
1521
- return "fp16" if fp16 else "fp32"
1522
-
1523
- def unlock_f0(value):
1524
- return {"choices": method_f0_full if value else method_f0, "value": "rmvpe", "__type__": "update"}
1525
-
1526
- def unlock_vocoder(value, vocoder):
1527
- return {"value": vocoder if value == "v2" else "Default", "interactive": value == "v2", "__type__": "update"}
1528
-
1529
- def unlock_ver(value, vocoder):
1530
- return {"value": "v2" if vocoder == "Default" else value, "interactive": vocoder == "Default", "__type__": "update"}
1531
-
1532
- def visible_embedders(value):
1533
- return {"visible": value != "spin", "__type__": "update"}
1534
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/app/parser.py DELETED
@@ -1,339 +0,0 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.getcwd())
5
-
6
- try:
7
- argv = sys.argv[1]
8
- except IndexError:
9
- argv = None
10
-
11
- argv_is_allows = ["--audio_effects", "--audioldm2", "--convert", "--create_dataset", "--create_index", "--extract", "--preprocess", "--separator_music", "--train", "--help_audio_effects", "--help_audioldm2", "--help_convert", "--help_create_dataset", "--help_create_index", "--help_extract", "--help_preprocess", "--help_separator_music", "--help_train", "--help"]
12
-
13
- if argv not in argv_is_allows:
14
- print("Invalid syntax! Use --help for more information")
15
- quit()
16
-
17
- if argv_is_allows[0] in argv: from main.inference.audio_effects import main
18
- elif argv_is_allows[1] in argv: from main.inference.audioldm2 import main
19
- elif argv_is_allows[2] in argv: from main.inference.convert import main
20
- elif argv_is_allows[3] in argv: from main.inference.create_dataset import main
21
- elif argv_is_allows[4] in argv: from main.inference.create_index import main
22
- elif argv_is_allows[5] in argv: from main.inference.extract import main
23
- elif argv_is_allows[6] in argv: from main.inference.preprocess import main
24
- elif argv_is_allows[7] in argv: from main.inference.separator_music import main
25
- elif argv_is_allows[8] in argv: from main.inference.train import main
26
- elif argv_is_allows[9] in argv:
27
- print("""Parameters for `--audio_effects`:
28
- 1. File paths:
29
- - `--input_path` (required): Path to the input audio file.
30
- - `--output_path` (default: `./audios/apply_effects.wav`): Path to save the output file.
31
- - `--export_format` (default: `wav`): Output file format (`wav`, `mp3`, ...).
32
-
33
- 2. Resampling:
34
- - `--resample` (default: `False`): Whether to resample or not.
35
- - `--resample_sr` (default: `0`): New sampling frequency (Hz).
36
-
37
- 3. Chorus effect:
38
- - `--chorus`: Enable/disable chorus.
39
- - `--chorus_depth`, `--chorus_rate`, `--chorus_mix`, `--chorus_delay`, `--chorus_feedback`: Parameters to adjust chorus.
40
-
41
- 4. Distortion effect:
42
- - `--distortion`: Enable/disable distortion.
43
- - `--drive_db`: Degree of audio distortion.
44
-
45
- 5. Reverb effect:
46
- - `--reverb`: Enable/disable reverb.
47
- - `--reverb_room_size`, `--reverb_damping`, `--reverb_wet_level`, `--reverb_dry_level`, `--reverb_width`, `--reverb_freeze_mode`: Adjust reverb.
48
-
49
- 6. Pitch shift effect:
50
- - `--pitchshift`: Enable/disable pitch shift.
51
- - `--pitch_shift`: Pitch shift value.
52
-
53
- 7. Delay effect:
54
- - `--delay`: Enable/disable delay.
55
- - `--delay_seconds`, `--delay_feedback`, `--delay_mix`: Adjust delay time, feedback, and mix.
56
-
57
- 8. Compressor:
58
- - `--compressor`: Enable/disable compressor.
59
- - `--compressor_threshold`, `--compressor_ratio`, `--compressor_attack_ms`, `--compressor_release_ms`: Compression parameters.
60
-
61
- 9. Limiter:
62
- - `--limiter`: Enable/disable audio level limiter.
63
- - `--limiter_threshold`, `--limiter_release`: Limiter threshold and release time.
64
-
65
- 10. Gain (Amplification):
66
- - `--gain`: Enable/disable gain.
67
- - `--gain_db`: Gain level (dB).
68
-
69
- 11. Bitcrush:
70
- - `--bitcrush`: Enable/disable bit resolution reduction effect.
71
- - `--bitcrush_bit_depth`: Bit depth for bitcrush.
72
-
73
- 12. Clipping:
74
- - `--clipping`: Enable/disable audio clipping.
75
- - `--clipping_threshold`: Clipping threshold.
76
-
77
- 13. Phaser:
78
- - `--phaser`: Enable/disable phaser effect.
79
- - `--phaser_rate_hz`, `--phaser_depth`, `--phaser_centre_frequency_hz`, `--phaser_feedback`, `--phaser_mix`: Adjust phaser effect.
80
-
81
- 14. Boost bass & treble:
82
- - `--treble_bass_boost`: Enable/disable bass and treble boost.
83
- - `--bass_boost_db`, `--bass_boost_frequency`, `--treble_boost_db`, `--treble_boost_frequency`: Bass and treble boost parameters.
84
-
85
- 15. Fade in & fade out:
86
- - `--fade_in_out`: Enable/disable fade effect.
87
- - `--fade_in_duration`, `--fade_out_duration`: Fade in/out duration.
88
-
89
- 16. Audio combination:
90
- - `--audio_combination`: Enable/disable combining multiple audio files.
91
- - `--audio_combination_input`: Path to additional audio files.
92
- """)
93
- quit()
94
- elif argv_is_allows[10] in argv:
95
- print("""Parameters for `--audioldm2`:
96
- 1. File paths:
97
- - `--input_path` (required): Path to the input audio file.
98
- - `--output_path` (default: `./output.wav`): Path to save the output file.
99
- - `--export_format` (default: `wav`): Output file format.
100
-
101
- 2. Audio configuration:
102
- - `--sample_rate` (default: `44100`): Sampling frequency (Hz).
103
-
104
- 3. AudioLDM model configuration:
105
- - `--audioldm_model` (default: `audioldm2-music`): Select AudioLDM model for processing.
106
-
107
- 4. Model guidance prompt:
108
- - `--source_prompt` (default: ``): Description of source audio.
109
- - `--target_prompt` (default: ``): Description of target audio.
110
-
111
- 5. Processing algorithm configuration:
112
- - `--steps` (default: `200`): Number of steps in audio synthesis process.
113
- - `--cfg_scale_src` (default: `3.5`): Guidance scale for source audio.
114
- - `--cfg_scale_tar` (default: `12`): Guidance scale for target audio.
115
- - `--t_start` (default: `45`): Editing level.
116
-
117
- 6. Computation optimization:
118
- - `--save_compute` (default: `False`): Whether to enable compute optimization mode.
119
- """)
120
- quit()
121
- elif argv_is_allows[11] in argv:
122
- print("""Parameters for `--convert`:
123
- 1. Voice processing configuration:
124
- - `--pitch` (default: `0`): Adjust pitch.
125
- - `--filter_radius` (default: `3`): F0 curve smoothness.
126
- - `--index_rate` (default: `0.5`): Voice index usage rate.
127
- - `--volume_envelope` (default: `1`): Volume amplitude adjustment factor.
128
- - `--protect` (default: `0.33`): Consonant protection.
129
-
130
- 2. Frame hop configuration:
131
- - `--hop_length` (default: `64`): Hop length during audio processing.
132
-
133
- 3. F0 configuration:
134
- - `--f0_method` (default: `rmvpe`): F0 prediction method (`pm`, `dio`, `mangio-crepe-tiny`, `mangio-crepe-small`, `mangio-crepe-medium`, `mangio-crepe-large`, `mangio-crepe-full`, `crepe-tiny`, `crepe-small`, `crepe-medium`, `crepe-large`, `crepe-full`, `fcpe`, `fcpe-legacy`, `rmvpe`, `rmvpe-legacy`, `harvest`, `yin`, `pyin`, `swipe`).
135
- - `--f0_autotune` (default: `False`): Whether to auto-tune F0.
136
- - `--f0_autotune_strength` (default: `1`): Strength of F0 auto-tuning.
137
- - `--f0_file` (default: ``): Path to existing F0 file.
138
- - `--f0_onnx` (default: `False`): Whether to use ONNX version of F0.
139
-
140
- 4. Embedding model:
141
- - `--embedder_model` (default: `contentvec_base`): Embedding model used.
142
- - `--embedders_mode` (default: `fairseq`): Embedding mode (`fairseq`, `transformers`, `onnx`).
143
-
144
- 5. File paths:
145
- - `--input_path` (required): Path to input audio file.
146
- - `--output_path` (default: `./audios/output.wav`): Path to save output file.
147
- - `--export_format` (default: `wav`): Output file format.
148
- - `--pth_path` (required): Path to `.pth` model file.
149
- - `--index_path` (default: `None`): Path to index file (if any).
150
-
151
- 6. Audio cleaning:
152
- - `--clean_audio` (default: `False`): Whether to apply audio cleaning.
153
- - `--clean_strength` (default: `0.7`): Cleaning strength.
154
-
155
- 7. Resampling & audio splitting:
156
- - `--resample_sr` (default: `0`): New sampling frequency (0 means keep original).
157
- - `--split_audio` (default: `False`): Whether to split audio before processing.
158
-
159
- 8. Testing & optimization:
160
- - `--checkpointing` (default: `False`): Enable/disable checkpointing to save RAM.
161
-
162
- 9. Formant shifting:
163
- - `--formant_shifting` (default: `False`): Whether to enable formant shifting effect.
164
- - `--formant_qfrency` (default: `0.8`): Formant shift frequency factor.
165
- - `--formant_timbre` (default: `0.8`): Voice timbre change factor.
166
- """)
167
- quit()
168
- elif argv_is_allows[12] in argv:
169
- print("""Parameters for `--create_dataset`:
170
- 1. Dataset paths & configuration:
171
- - `--input_audio` (required): Path to audio link (YouTube link, can use `,` for multiple links).
172
- - `--output_dataset` (default: `./dataset`): Output data directory.
173
- - `--sample_rate` (default: `44100`): Audio sampling frequency.
174
-
175
- 2. Data cleaning:
176
- - `--clean_dataset` (default: `False`): Whether to apply data cleaning.
177
- - `--clean_strength` (default: `0.7`): Data cleaning strength.
178
-
179
- 3. Voice separation & effects:
180
- - `--separator_reverb` (default: `False`): Whether to separate voice reverb.
181
- - `--kim_vocal_version` (default: `2`): Kim Vocal model version for separation (`1`, `2`).
182
-
183
- 4. Audio segmentation configuration:
184
- - `--overlap` (default: `0.25`): Overlap level between segments during separation.
185
- - `--segments_size` (default: `256`): Size of each segment.
186
-
187
- 5. MDX (Music Demixing) configuration:
188
- - `--mdx_hop_length` (default: `1024`): MDX hop length during processing.
189
- - `--mdx_batch_size` (default: `1`): Batch size during MDX processing.
190
- - `--denoise_mdx` (default: `False`): Whether to apply denoising during MDX separation.
191
-
192
- 6. Skip audio sections:
193
- - `--skip` (default: `False`): Whether to skip any audio seconds.
194
- - `--skip_start_audios` (default: `0`): Time (seconds) to skip at the start of audio.
195
- - `--skip_end_audios` (default: `0`): Time (seconds) to skip at the end of audio.
196
- """)
197
- quit()
198
- elif argv_is_allows[13] in argv:
199
- print("""Parameters for `--create_index`:
200
- 1. Model information:
201
- - `--model_name` (required): Model name.
202
- - `--rvc_version` (default: `v2`): Version (`v1`, `v2`).
203
- - `--index_algorithm` (default: `Auto`): Index algorithm used (`Auto`, `Faiss`, `KMeans`).
204
- """)
205
- quit()
206
- elif argv_is_allows[14] in argv:
207
- print("""Parameters for `--extract`:
208
- 1. Model information:
209
- - `--model_name` (required): Model name.
210
- - `--rvc_version` (default: `v2`): RVC version (`v1`, `v2`).
211
-
212
- 2. F0 configuration:
213
- - `--f0_method` (default: `rmvpe`): F0 prediction method (`pm`, `dio`, `mangio-crepe-tiny`, `mangio-crepe-small`, `mangio-crepe-medium`, `mangio-crepe-large`, `mangio-crepe-full`, `crepe-tiny`, `crepe-small`, `crepe-medium`, `crepe-large`, `crepe-full`, `fcpe`, `fcpe-legacy`, `rmvpe`, `rmvpe-legacy`, `harvest`, `yin`, `pyin`, `swipe`).
214
- - `--pitch_guidance` (default: `True`): Whether to use pitch guidance.
215
-
216
- 3. Processing configuration:
217
- - `--hop_length` (default: `128`): Hop length during processing.
218
- - `--cpu_cores` (default: `2`): Number of CPU threads used.
219
- - `--gpu` (default: `-`): Specify GPU to use (e.g., `0` for first GPU, `-` to disable GPU).
220
- - `--sample_rate` (required): Input audio sampling frequency.
221
-
222
- 4. Embedding configuration:
223
- - `--embedder_model` (default: `contentvec_base`): Embedding model name.
224
- - `--f0_onnx` (default: `False`): Whether to use ONNX version of F0.
225
- - `--embedders_mode` (default: `fairseq`): Embedding mode (`fairseq`, `transformers`, `onnx`).
226
- """)
227
- quit()
228
- elif argv_is_allows[15] in argv:
229
- print("""Parameters for `--preprocess`:
230
- 1. Model information:
231
- - `--model_name` (required): Model name.
232
-
233
- 2. Data configuration:
234
- - `--dataset_path` (default: `./dataset`): Path to directory containing data files.
235
- - `--sample_rate` (required): Audio data sampling frequency.
236
-
237
- 3. Processing configuration:
238
- - `--cpu_cores` (default: `2`): Number of CPU threads used.
239
- - `--cut_preprocess` (default: `True`): Whether to cut data files.
240
- - `--process_effects` (default: `False`): Whether to apply preprocessing.
241
- - `--clean_dataset` (default: `False`): Whether to clean data files.
242
- - `--clean_strength` (default: `0.7`): Data cleaning strength.
243
- """)
244
- quit()
245
- elif argv_is_allows[16] in argv:
246
- print("""Parameters for `--separator_music`:
247
- 1. Data paths:
248
- - `--input_path` (required): Path to input audio file.
249
- - `--output_path` (default: `./audios`): Directory to save output files.
250
- - `--format` (default: `wav`): Output file format (`wav`, `mp3`, ...).
251
-
252
- 2. Audio processing configuration:
253
- - `--shifts` (default: `2`): Number of predictions.
254
- - `--segments_size` (default: `256`): Audio segment size.
255
- - `--overlap` (default: `0.25`): Overlap level between segments.
256
- - `--mdx_hop_length` (default: `1024`): MDX hop length during processing.
257
- - `--mdx_batch_size` (default: `1`): Batch size.
258
-
259
- 3. Cleaning processing:
260
- - `--clean_audio` (default: `False`): Whether to clean audio.
261
- - `--clean_strength` (default: `0.7`): Cleaning filter strength.
262
-
263
- 4. Model configuration:
264
- - `--model_name` (default: `HT-Normal`): Music separation model (`Main_340`, `Main_390`, `Main_406`, `Main_427`, `Main_438`, `Inst_full_292`, `Inst_HQ_1`, `Inst_HQ_2`, `Inst_HQ_3`, `Inst_HQ_4`, `Inst_HQ_5`, `Kim_Vocal_1`, `Kim_Vocal_2`, `Kim_Inst`, `Inst_187_beta`, `Inst_82_beta`, `Inst_90_beta`, `Voc_FT`, `Crowd_HQ`, `Inst_1`, `Inst_2`, `Inst_3`, `MDXNET_1_9703`, `MDXNET_2_9682`, `MDXNET_3_9662`, `Inst_Main`, `MDXNET_Main`, `MDXNET_9482`, `HT-Normal`, `HT-Tuned`, `HD_MMI`, `HT_6S`).
265
- - `--kara_model` (default: `Version-1`): Backing track separation model version (`Version-1`, `Version-2`).
266
-
267
- 5. Effects and post-processing:
268
- - `--backing` (default: `False`): Whether to separate backing vocals.
269
- - `--mdx_denoise` (default: `False`): Whether to use MDX denoising.
270
- - `--reverb` (default: `False`): Whether to separate reverb.
271
- - `--backing_reverb` (default: `False`): Whether to separate reverb for backing vocals.
272
-
273
- 6. Sampling frequency:
274
- - `--sample_rate` (default: `44100`): Output audio sampling frequency.
275
- """)
276
- quit()
277
- elif argv_is_allows[17] in argv:
278
- print("""Parameters for `--train`:
279
- 1. Model configuration:
280
- - `--model_name` (required): Model name.
281
- - `--rvc_version` (default: `v2`): RVC version (`v1`, `v2`).
282
- - `--model_author` (optional): Model author.
283
-
284
- 2. Save configuration:
285
- - `--save_every_epoch` (required): Number of epochs between saves.
286
- - `--save_only_latest` (default: `True`): Save only the latest checkpoint.
287
- - `--save_every_weights` (default: `True`): Save all model weights.
288
-
289
- 3. Training configuration:
290
- - `--total_epoch` (default: `300`): Total number of training epochs.
291
- - `--batch_size` (default: `8`): Batch size during training.
292
- - `--sample_rate` (required): Audio sampling frequency.
293
-
294
- 4. Device configuration:
295
- - `--gpu` (default: `0`): Specify GPU to use (agena: Specify GPU to use (number or `-` if not using GPU).
296
- - `--cache_data_in_gpu` (default: `False`): Cache data in GPU for faster processing.
297
-
298
- 5. Advanced training configuration:
299
- - `--pitch_guidance` (default: `True`): Use pitch guidance.
300
- - `--g_pretrained_path` (default: ``): Path to pretrained G weights.
301
- - `--d_pretrained_path` (default: ``): Path to pretrained D weights.
302
- - `--vocoder` (default: `Default`): Vocoder used (`Default`, `MRF-HiFi-GAN`, `RefineGAN`).
303
-
304
- 6. Overtraining detection:
305
- - `--overtraining_detector` (default: `False`): Enable/disable overtraining detection.
306
- - `--overtraining_threshold` (default: `50`): Threshold for detecting overtraining.
307
-
308
- 7. Data processing:
309
- - `--cleanup` (default: `False`): Clean old training files to start training from scratch.
310
-
311
- 8. Optimization:
312
- - `--checkpointing` (default: `False`): Enable/disable checkpointing to save RAM.
313
- - `--deterministic` (default: `False`): When enabled, uses deterministic algorithms to ensure consistent results for the same input data.
314
- - `--benchmark` (default: `False`): When enabled, tests and selects the optimal algorithm for the hardware and specific size.
315
- """)
316
- quit()
317
- elif argv_is_allows[18] in argv:
318
- print("""Usage:
319
- 1. `--help_audio_effects`: Help for adding audio effects.
320
- 2. `--help_audioldm2`: Help for music editing.
321
- 3. `--help_convert`: Help for audio conversion.
322
- 4. `--help_create_dataset`: Help for creating training data.
323
- 5. `--help_create_index`: Help for creating an index.
324
- 6. `--help_extract`: Help for extracting training data.
325
- 7. `--help_preprocess`: Help for preprocessing data.
326
- 8. `--help_separator_music`: Help for music separation.
327
- 9. `--help_train`: Help for model training.
328
- """)
329
- quit()
330
-
331
- if __name__ == "__main__":
332
- if "--train" in argv:
333
- import torch.multiprocessing as mp
334
- mp.set_start_method("spawn")
335
-
336
- try:
337
- main()
338
- except:
339
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/app/tabs/inference/inference.py DELETED
@@ -1,596 +0,0 @@
1
- import gradio as gr
2
- from main.tools import huggingface
3
- from main.configs.config import Config
4
- from main.app.based.utils import *
5
-
6
- def inference_tabs():
7
- # Audio Conversion Tab
8
- with gr.TabItem(translations["convert_audio"], visible=configs.get("convert_tab", True)):
9
- gr.Markdown(f"## {translations['convert_audio']}")
10
- with gr.Row():
11
- gr.Markdown(translations["convert_info"])
12
-
13
- with gr.Row():
14
- with gr.Column():
15
- with gr.Accordion(translations["model_accordion"], open=True):
16
- with gr.Row(equal_height=True):
17
- model_pth = gr.Dropdown(label=translations["model_name"], choices=model_name, value=model_name[0] if len(model_name) >= 1 else "", interactive=True, allow_custom_value=True)
18
- model_index = gr.Dropdown(label=translations["index_path"], choices=index_path, value=index_path[0] if len(index_path) >= 1 else "", interactive=True, allow_custom_value=True)
19
- refesh = gr.Button(translations["refesh"])
20
-
21
- with gr.Row():
22
- with gr.Column():
23
- audio_select = gr.Dropdown(label=translations["select_separate"], choices=[], value="", interactive=True, allow_custom_value=True, visible=False)
24
- convert_button_2 = gr.Button(translations["convert_audio"], visible=False)
25
- with gr.Row():
26
- with gr.Column():
27
- input0 = gr.File(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"])
28
- play_audio = gr.Audio(show_download_button=True, interactive=False, label=translations["input_audio"])
29
- with gr.Row():
30
- with gr.Column():
31
- with gr.Row():
32
- index_strength = gr.Slider(label=translations["index_strength"], info=translations["index_strength_info"], minimum=0, maximum=1, value=0.5, step=0.01, interactive=True, visible=model_index.value != "")
33
- with gr.Row():
34
- with gr.Column():
35
- with gr.Accordion(translations["input_output"], open=False):
36
- with gr.Column():
37
- export_format = gr.Radio(label=translations["export_format"], info=translations["export_info"], choices=["wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"], value="wav", interactive=True)
38
- input_audio0 = gr.Dropdown(label=translations["audio_path"], value="", choices=paths_for_files, info=translations["provide_audio"], allow_custom_value=True, interactive=True)
39
- output_audio = gr.Textbox(label=translations["output_path"], value="audios/output.wav", placeholder="audios/output.wav", info=translations["output_path_info"], interactive=True)
40
- with gr.Column():
41
- refesh0 = gr.Button(translations["refesh"])
42
- with gr.Accordion(translations["setting"], open=False):
43
- with gr.Row():
44
- cleaner0 = gr.Checkbox(label=translations["clear_audio"], value=False, interactive=True)
45
- autotune = gr.Checkbox(label=translations["autotune"], value=False, interactive=True)
46
- use_audio = gr.Checkbox(label=translations["use_audio"], value=False, interactive=True)
47
- checkpointing = gr.Checkbox(label=translations["memory_efficient_training"], value=False, interactive=True)
48
- with gr.Row():
49
- use_original = gr.Checkbox(label=translations["convert_original"], value=False, interactive=True, visible=use_audio.value)
50
- convert_backing = gr.Checkbox(label=translations["convert_backing"], value=False, interactive=True, visible=use_audio.value)
51
- not_merge_backing = gr.Checkbox(label=translations["not_merge_backing"], value=False, interactive=True, visible=use_audio.value)
52
- merge_instrument = gr.Checkbox(label=translations["merge_instruments"], value=False, interactive=True, visible=use_audio.value)
53
- with gr.Row():
54
- pitch = gr.Slider(minimum=-20, maximum=20, step=1, info=translations["pitch_info"], label=translations["pitch"], value=0, interactive=True)
55
- clean_strength0 = gr.Slider(label=translations["clean_strength"], info=translations["clean_strength_info"], minimum=0, maximum=1, value=0.5, step=0.1, interactive=True, visible=cleaner0.value)
56
-
57
- with gr.Accordion(translations["f0_method"], open=False):
58
- with gr.Group():
59
- with gr.Row():
60
- onnx_f0_mode = gr.Checkbox(label=translations["f0_onnx_mode"], info=translations["f0_onnx_mode_info"], value=False, interactive=True)
61
- unlock_full_method = gr.Checkbox(label=translations["f0_unlock"], info=translations["f0_unlock_info"], value=False, interactive=True)
62
- method = gr.Radio(label=translations["f0_method"], info=translations["f0_method_info"], choices=method_f0+["hybrid"], value="rmvpe", interactive=True)
63
- hybrid_method = gr.Dropdown(label=translations["f0_method_hybrid"], info=translations["f0_method_hybrid_info"], choices=["hybrid[pm+dio]", "hybrid[pm+crepe-tiny]", "hybrid[pm+crepe]", "hybrid[pm+fcpe]", "hybrid[pm+rmvpe]", "hybrid[pm+harvest]", "hybrid[pm+yin]", "hybrid[dio+crepe-tiny]", "hybrid[dio+crepe]", "hybrid[dio+fcpe]", "hybrid[dio+rmvpe]", "hybrid[dio+harvest]", "hybrid[dio+yin]", "hybrid[crepe-tiny+crepe]", "hybrid[crepe-tiny+fcpe]", "hybrid[crepe-tiny+rmvpe]", "hybrid[crepe-tiny+harvest]", "hybrid[crepe+fcpe]", "hybrid[crepe+rmvpe]", "hybrid[crepe+harvest]", "hybrid[crepe+yin]", "hybrid[fcpe+rmvpe]", "hybrid[fcpe+harvest]", "hybrid[fcpe+yin]", "hybrid[rmvpe+harvest]", "hybrid[rmvpe+yin]", "hybrid[harvest+yin]"], value="hybrid[pm+dio]", interactive=True, allow_custom_value=True, visible=method.value == "hybrid")
64
- hop_length = gr.Slider(label="Hop length", info=translations["hop_length_info"], minimum=1, maximum=512, value=128, step=1, interactive=True, visible=False)
65
- with gr.Accordion(translations["f0_file"], open=False):
66
- upload_f0_file = gr.File(label=translations["upload_f0"], file_types=[".txt"])
67
- f0_file_dropdown = gr.Dropdown(label=translations["f0_file_2"], value="", choices=f0_file, allow_custom_value=True, interactive=True)
68
- refesh_f0_file = gr.Button(translations["refesh"])
69
- with gr.Accordion(translations["hubert_model"], open=False):
70
- embed_mode = gr.Radio(label=translations["embed_mode"], info=translations["embed_mode_info"], value="fairseq", choices=embedders_mode, interactive=True, visible=True)
71
- embedders = gr.Radio(label=translations["hubert_model"], info=translations["hubert_info"], choices=embedders_model, value="hubert_base", interactive=True)
72
- custom_embedders = gr.Textbox(label=translations["modelname"], info=translations["modelname_info"], value="", placeholder="hubert_base", interactive=True, visible=embedders.value == "custom")
73
- with gr.Accordion(translations["use_presets"], open=False):
74
- with gr.Row():
75
- presets_name = gr.Dropdown(label=translations["file_preset"], choices=presets_file, value=presets_file[0] if len(presets_file) > 0 else '', interactive=True, allow_custom_value=True)
76
- with gr.Row():
77
- load_click = gr.Button(translations["load_file"], variant="primary")
78
- refesh_click = gr.Button(translations["refesh"])
79
- with gr.Accordion(translations["export_file"], open=False):
80
- with gr.Row():
81
- with gr.Column():
82
- with gr.Group():
83
- with gr.Row():
84
- cleaner_chbox = gr.Checkbox(label=translations["save_clean"], value=True, interactive=True)
85
- autotune_chbox = gr.Checkbox(label=translations["save_autotune"], value=True, interactive=True)
86
- pitch_chbox = gr.Checkbox(label=translations["save_pitch"], value=True, interactive=True)
87
- index_strength_chbox = gr.Checkbox(label=translations["save_index_2"], value=True, interactive=True)
88
- resample_sr_chbox = gr.Checkbox(label=translations["save_resample"], value=True, interactive=True)
89
- filter_radius_chbox = gr.Checkbox(label=translations["save_filter"], value=True, interactive=True)
90
- volume_envelope_chbox = gr.Checkbox(label=translations["save_envelope"], value=True, interactive=True)
91
- protect_chbox = gr.Checkbox(label=translations["save_protect"], value=True, interactive=True)
92
- split_audio_chbox = gr.Checkbox(label=translations["save_split"], value=True, interactive=True)
93
- formant_shifting_chbox = gr.Checkbox(label=translations["formantshift"], value=True, interactive=True)
94
- with gr.Row():
95
- with gr.Column():
96
- name_to_save_file = gr.Textbox(label=translations["filename_to_save"])
97
- save_file_button = gr.Button(translations["export_file"])
98
- with gr.Row():
99
- upload_presets = gr.File(label=translations["upload_presets"], file_types=[".json"])
100
- with gr.Column():
101
- with gr.Row():
102
- split_audio = gr.Checkbox(label=translations["split_audio"], value=False, interactive=True)
103
- formant_shifting = gr.Checkbox(label=translations["formantshift"], value=False, interactive=True)
104
- f0_autotune_strength = gr.Slider(minimum=0, maximum=1, label=translations["autotune_rate"], info=translations["autotune_rate_info"], value=1, step=0.1, interactive=True, visible=autotune.value)
105
- resample_sr = gr.Slider(minimum=0, maximum=96000, label=translations["resample"], info=translations["resample_info"], value=0, step=1, interactive=True)
106
- filter_radius = gr.Slider(minimum=0, maximum=7, label=translations["filter_radius"], info=translations["filter_radius_info"], value=3, step=1, interactive=True)
107
- volume_envelope = gr.Slider(minimum=0, maximum=1, label=translations["volume_envelope"], info=translations["volume_envelope_info"], value=1, step=0.1, interactive=True)
108
- protect = gr.Slider(minimum=0, maximum=1, label=translations["protect"], info=translations["protect_info"], value=0.5, step=0.01, interactive=True)
109
- with gr.Row():
110
- formant_qfrency = gr.Slider(value=1.0, label=translations["formant_qfrency"], info=translations["formant_qfrency"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
111
- formant_timbre = gr.Slider(value=1.0, label=translations["formant_timbre"], info=translations["formant_timbre"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
112
- with gr.Row():
113
- convert_button = gr.Button(translations["convert_audio"], variant="primary")
114
- gr.Markdown(translations["output_convert"])
115
- with gr.Row():
116
- main_convert = gr.Audio(show_download_button=True, interactive=False, label=translations["main_convert"])
117
- backing_convert = gr.Audio(show_download_button=True, interactive=False, label=translations["convert_backing"], visible=convert_backing.value)
118
- main_backing = gr.Audio(show_download_button=True, interactive=False, label=translations["main_or_backing"], visible=convert_backing.value)
119
- with gr.Row():
120
- original_convert = gr.Audio(show_download_button=True, interactive=False, label=translations["convert_original"], visible=use_original.value)
121
- vocal_instrument = gr.Audio(show_download_button=True, interactive=False, label=translations["voice_or_instruments"], visible=merge_instrument.value)
122
- with gr.Row():
123
- upload_f0_file.upload(fn=lambda inp: shutil.move(inp.name, os.path.join("assets", "f0")), inputs=[upload_f0_file], outputs=[f0_file_dropdown])
124
- refesh_f0_file.click(fn=change_f0_choices, inputs=[], outputs=[f0_file_dropdown])
125
- unlock_full_method.change(fn=unlock_f0, inputs=[unlock_full_method], outputs=[method])
126
- with gr.Row():
127
- load_click.click(
128
- fn=load_presets,
129
- inputs=[
130
- presets_name,
131
- cleaner0,
132
- autotune,
133
- pitch,
134
- clean_strength0,
135
- index_strength,
136
- resample_sr,
137
- filter_radius,
138
- volume_envelope,
139
- protect,
140
- split_audio,
141
- f0_autotune_strength,
142
- formant_qfrency,
143
- formant_timbre
144
- ],
145
- outputs=[
146
- cleaner0,
147
- autotune,
148
- pitch,
149
- clean_strength0,
150
- index_strength,
151
- resample_sr,
152
- filter_radius,
153
- volume_envelope,
154
- protect,
155
- split_audio,
156
- f0_autotune_strength,
157
- formant_shifting,
158
- formant_qfrency,
159
- formant_timbre
160
- ]
161
- )
162
- refesh_click.click(fn=change_preset_choices, inputs=[], outputs=[presets_name])
163
- save_file_button.click(
164
- fn=save_presets,
165
- inputs=[
166
- name_to_save_file,
167
- cleaner0,
168
- autotune,
169
- pitch,
170
- clean_strength0,
171
- index_strength,
172
- resample_sr,
173
- filter_radius,
174
- volume_envelope,
175
- protect,
176
- split_audio,
177
- f0_autotune_strength,
178
- cleaner_chbox,
179
- autotune_chbox,
180
- pitch_chbox,
181
- index_strength_chbox,
182
- resample_sr_chbox,
183
- filter_radius_chbox,
184
- volume_envelope_chbox,
185
- protect_chbox,
186
- split_audio_chbox,
187
- formant_shifting_chbox,
188
- formant_shifting,
189
- formant_qfrency,
190
- formant_timbre
191
- ],
192
- outputs=[presets_name]
193
- )
194
- with gr.Row():
195
- upload_presets.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("assets", "presets")), inputs=[upload_presets], outputs=[presets_name])
196
- autotune.change(fn=visible, inputs=[autotune], outputs=[f0_autotune_strength])
197
- use_audio.change(fn=lambda a: [visible(a), visible(a), visible(a), visible(a), visible(a), valueFalse_interactive(a), valueFalse_interactive(a), valueFalse_interactive(a), valueFalse_interactive(a), visible(not a), visible(not a), visible(not a), visible(not a)], inputs=[use_audio], outputs=[main_backing, use_original, convert_backing, not_merge_backing, merge_instrument, use_original, convert_backing, not_merge_backing, merge_instrument, input_audio0, output_audio, input0, play_audio])
198
- with gr.Row():
199
- convert_backing.change(fn=lambda a,b: [change_backing_choices(a, b), visible(a)], inputs=[convert_backing, not_merge_backing], outputs=[use_original, backing_convert])
200
- use_original.change(fn=lambda audio, original: [visible(original), visible(not original), visible(audio and not original), valueFalse_interactive(not original), valueFalse_interactive(not original)], inputs=[use_audio, use_original], outputs=[original_convert, main_convert, main_backing, convert_backing, not_merge_backing])
201
- cleaner0.change(fn=visible, inputs=[cleaner0], outputs=[clean_strength0])
202
- with gr.Row():
203
- merge_instrument.change(fn=visible, inputs=[merge_instrument], outputs=[vocal_instrument])
204
- not_merge_backing.change(fn=lambda audio, merge, cvb: [visible(audio and not merge), change_backing_choices(cvb, merge)], inputs=[use_audio, not_merge_backing, convert_backing], outputs=[main_backing, use_original])
205
- method.change(fn=lambda method, hybrid: [visible(method == "hybrid"), hoplength_show(method, hybrid)], inputs=[method, hybrid_method], outputs=[hybrid_method, hop_length])
206
- with gr.Row():
207
- hybrid_method.change(fn=hoplength_show, inputs=[method, hybrid_method], outputs=[hop_length])
208
- refesh.click(fn=change_models_choices, inputs=[], outputs=[model_pth, model_index])
209
- model_pth.change(fn=get_index, inputs=[model_pth], outputs=[model_index])
210
- with gr.Row():
211
- input0.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("audios")), inputs=[input0], outputs=[input_audio0])
212
- input_audio0.change(fn=lambda audio: audio if os.path.isfile(audio) else None, inputs=[input_audio0], outputs=[play_audio])
213
- formant_shifting.change(fn=lambda a: [visible(a)]*2, inputs=[formant_shifting], outputs=[formant_qfrency, formant_timbre])
214
- with gr.Row():
215
- embedders.change(fn=lambda embedders: visible(embedders == "custom"), inputs=[embedders], outputs=[custom_embedders])
216
- refesh0.click(fn=change_audios_choices, inputs=[input_audio0], outputs=[input_audio0])
217
- model_index.change(fn=index_strength_show, inputs=[model_index], outputs=[index_strength])
218
- with gr.Row():
219
- audio_select.change(fn=lambda: visible(True), inputs=[], outputs=[convert_button_2])
220
- convert_button.click(fn=lambda: visible(False), inputs=[], outputs=[convert_button])
221
- convert_button_2.click(fn=lambda: [visible(False), visible(False)], inputs=[], outputs=[audio_select, convert_button_2])
222
- with gr.Row():
223
- convert_button.click(
224
- fn=convert_selection,
225
- inputs=[
226
- cleaner0,
227
- autotune,
228
- use_audio,
229
- use_original,
230
- convert_backing,
231
- not_merge_backing,
232
- merge_instrument,
233
- pitch,
234
- clean_strength0,
235
- model_pth,
236
- model_index,
237
- index_strength,
238
- input_audio0,
239
- output_audio,
240
- export_format,
241
- method,
242
- hybrid_method,
243
- hop_length,
244
- embedders,
245
- custom_embedders,
246
- resample_sr,
247
- filter_radius,
248
- volume_envelope,
249
- protect,
250
- split_audio,
251
- f0_autotune_strength,
252
- checkpointing,
253
- onnx_f0_mode,
254
- formant_shifting,
255
- formant_qfrency,
256
- formant_timbre,
257
- f0_file_dropdown,
258
- embed_mode
259
- ],
260
- outputs=[audio_select, main_convert, backing_convert, main_backing, original_convert, vocal_instrument, convert_button],
261
- api_name="convert_selection"
262
- )
263
- embed_mode.change(fn=visible_embedders, inputs=[embed_mode], outputs=[embedders])
264
- convert_button_2.click(
265
- fn=convert_audio,
266
- inputs=[
267
- cleaner0,
268
- autotune,
269
- use_audio,
270
- use_original,
271
- convert_backing,
272
- not_merge_backing,
273
- merge_instrument,
274
- pitch,
275
- clean_strength0,
276
- model_pth,
277
- model_index,
278
- index_strength,
279
- input_audio0,
280
- output_audio,
281
- export_format,
282
- method,
283
- hybrid_method,
284
- hop_length,
285
- embedders,
286
- custom_embedders,
287
- resample_sr,
288
- filter_radius,
289
- volume_envelope,
290
- protect,
291
- split_audio,
292
- f0_autotune_strength,
293
- audio_select,
294
- checkpointing,
295
- onnx_f0_mode,
296
- formant_shifting,
297
- formant_qfrency,
298
- formant_timbre,
299
- f0_file_dropdown,
300
- embed_mode
301
- ],
302
- outputs=[main_convert, backing_convert, main_backing, original_convert, vocal_instrument, convert_button],
303
- api_name="convert_audio"
304
- )
305
-
306
- # Text-to-Speech Conversion Tab
307
- with gr.TabItem(translations["convert_text"], visible=configs.get("tts_tab", True)):
308
- gr.Markdown(translations["convert_text_markdown"])
309
- with gr.Row():
310
- gr.Markdown(translations["convert_text_markdown_2"])
311
- with gr.Accordion(translations["model_accordion"], open=True):
312
- with gr.Row(equal_height=True):
313
- model_pth0 = gr.Dropdown(label=translations["model_name"], choices=model_name, value=model_name[0] if len(model_name) >= 1 else "", interactive=True, allow_custom_value=True)
314
- model_index0 = gr.Dropdown(label=translations["index_path"], choices=index_path, value=index_path[0] if len(index_path) >= 1 else "", interactive=True, allow_custom_value=True)
315
- refesh1 = gr.Button(translations["refesh"])
316
-
317
- with gr.Row():
318
- with gr.Column():
319
- with gr.Group():
320
- with gr.Row():
321
- use_txt = gr.Checkbox(label=translations["input_txt"], value=False, interactive=True)
322
- google_tts_check_box = gr.Checkbox(label=translations["googletts"], value=False, interactive=True)
323
- prompt = gr.Textbox(label=translations["text_to_speech"], value="", placeholder="Hello Words", lines=3)
324
- with gr.Column():
325
- speed = gr.Slider(label=translations["voice_speed"], info=translations["voice_speed_info"], minimum=-100, maximum=100, value=0, step=1)
326
- pitch0 = gr.Slider(minimum=-20, maximum=20, step=1, info=translations["pitch_info"], label=translations["pitch"], value=0, interactive=True)
327
- with gr.Row():
328
- tts_button = gr.Button(translations["tts_1"], variant="primary", scale=2)
329
- convert_button0 = gr.Button(translations["tts_2"], variant="secondary", scale=2)
330
- with gr.Row():
331
- with gr.Column():
332
- txt_input = gr.File(label=translations["drop_text"], file_types=[".txt", ".srt"], visible=use_txt.value)
333
- tts_voice = gr.Dropdown(label=translations["voice"], choices=edgetts, interactive=True, value="vi-VN-NamMinhNeural")
334
- tts_pitch = gr.Slider(minimum=-20, maximum=20, step=1, info=translations["pitch_info_2"], label=translations["pitch"], value=0, interactive=True)
335
- with gr.Column():
336
-
337
- with gr.Row():
338
- index_strength0 = gr.Slider(label=translations["index_strength"], info=translations["index_strength_info"], minimum=0, maximum=1, value=0.5, step=0.01, interactive=True, visible=model_index0.value != "")
339
- with gr.Accordion(translations["output_path"], open=False):
340
- export_format0 = gr.Radio(label=translations["export_format"], info=translations["export_info"], choices=["wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"], value="wav", interactive=True)
341
- output_audio0 = gr.Textbox(label=translations["output_tts"], value="audios/tts.wav", placeholder="audios/tts.wav", info=translations["tts_output"], interactive=True)
342
- output_audio1 = gr.Textbox(label=translations["output_tts_convert"], value="audios/tts-convert.wav", placeholder="audios/tts-convert.wav", info=translations["tts_output"], interactive=True)
343
- with gr.Accordion(translations["setting"], open=False):
344
- with gr.Accordion(translations["f0_method"], open=False):
345
- with gr.Group():
346
- with gr.Row():
347
- onnx_f0_mode1 = gr.Checkbox(label=translations["f0_onnx_mode"], info=translations["f0_onnx_mode_info"], value=False, interactive=True)
348
- unlock_full_method3 = gr.Checkbox(label=translations["f0_unlock"], info=translations["f0_unlock_info"], value=False, interactive=True)
349
- method0 = gr.Radio(label=translations["f0_method"], info=translations["f0_method_info"], choices=method_f0+["hybrid"], value="rmvpe", interactive=True)
350
- hybrid_method0 = gr.Dropdown(label=translations["f0_method_hybrid"], info=translations["f0_method_hybrid_info"], choices=["hybrid[pm+dio]", "hybrid[pm+crepe-tiny]", "hybrid[pm+crepe]", "hybrid[pm+fcpe]", "hybrid[pm+rmvpe]", "hybrid[pm+harvest]", "hybrid[pm+yin]", "hybrid[dio+crepe-tiny]", "hybrid[dio+crepe]", "hybrid[dio+fcpe]", "hybrid[dio+rmvpe]", "hybrid[dio+harvest]", "hybrid[dio+yin]", "hybrid[crepe-tiny+crepe]", "hybrid[crepe-tiny+fcpe]", "hybrid[crepe-tiny+rmvpe]", "hybrid[crepe-tiny+harvest]", "hybrid[crepe+fcpe]", "hybrid[crepe+rmvpe]", "hybrid[crepe+harvest]", "hybrid[crepe+yin]", "hybrid[fcpe+rmvpe]", "hybrid[fcpe+harvest]", "hybrid[fcpe+yin]", "hybrid[rmvpe+harvest]", "hybrid[rmvpe+yin]", "hybrid[harvest+yin]"], value="hybrid[pm+dio]", interactive=True, allow_custom_value=True, visible=method0.value == "hybrid")
351
- hop_length0 = gr.Slider(label="Hop length", info=translations["hop_length_info"], minimum=1, maximum=512, value=128, step=1, interactive=True, visible=False)
352
- with gr.Accordion(translations["f0_file"], open=False):
353
- upload_f0_file0 = gr.File(label=translations["upload_f0"], file_types=[".txt"])
354
- f0_file_dropdown0 = gr.Dropdown(label=translations["f0_file_2"], value="", choices=f0_file, allow_custom_value=True, interactive=True)
355
- refesh_f0_file0 = gr.Button(translations["refesh"])
356
- with gr.Accordion(translations["hubert_model"], open=False):
357
- embed_mode1 = gr.Radio(label=translations["embed_mode"], info=translations["embed_mode_info"], value="fairseq", choices=embedders_mode, interactive=True, visible=True)
358
- embedders0 = gr.Radio(label=translations["hubert_model"], info=translations["hubert_info"], choices=embedders_model, value="hubert_base", interactive=True)
359
- custom_embedders0 = gr.Textbox(label=translations["modelname"], info=translations["modelname_info"], value="", placeholder="hubert_base", interactive=True, visible=embedders0.value == "custom")
360
- with gr.Group():
361
- with gr.Row():
362
- formant_shifting1 = gr.Checkbox(label=translations["formantshift"], value=False, interactive=True)
363
- split_audio0 = gr.Checkbox(label=translations["split_audio"], value=False, interactive=True)
364
- cleaner1 = gr.Checkbox(label=translations["clear_audio"], value=False, interactive=True)
365
- autotune3 = gr.Checkbox(label=translations["autotune"], value=False, interactive=True)
366
- checkpointing0 = gr.Checkbox(label=translations["memory_efficient_training"], value=False, interactive=True)
367
- with gr.Column():
368
- f0_autotune_strength0 = gr.Slider(minimum=0, maximum=1, label=translations["autotune_rate"], info=translations["autotune_rate_info"], value=1, step=0.1, interactive=True, visible=autotune3.value)
369
- clean_strength1 = gr.Slider(label=translations["clean_strength"], info=translations["clean_strength_info"], minimum=0, maximum=1, value=0.5, step=0.1, interactive=True, visible=cleaner1.value)
370
- resample_sr0 = gr.Slider(minimum=0, maximum=96000, label=translations["resample"], info=translations["resample_info"], value=0, step=1, interactive=True)
371
- filter_radius0 = gr.Slider(minimum=0, maximum=7, label=translations["filter_radius"], info=translations["filter_radius_info"], value=3, step=1, interactive=True)
372
- volume_envelope0 = gr.Slider(minimum=0, maximum=1, label=translations["volume_envelope"], info=translations["volume_envelope_info"], value=1, step=0.1, interactive=True)
373
- protect0 = gr.Slider(minimum=0, maximum=1, label=translations["protect"], info=translations["protect_info"], value=0.5, step=0.01, interactive=True)
374
- with gr.Row():
375
- formant_qfrency1 = gr.Slider(value=1.0, label=translations["formant_qfrency"], info=translations["formant_qfrency"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
376
- formant_timbre1 = gr.Slider(value=1.0, label=translations["formant_timbre"], info=translations["formant_timbre"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
377
- with gr.Row():
378
- gr.Markdown(translations["output_tts_markdown"])
379
- with gr.Row():
380
- tts_voice_audio = gr.Audio(show_download_button=True, interactive=False, label=translations["output_text_to_speech"])
381
- tts_voice_convert = gr.Audio(show_download_button=True, interactive=False, label=translations["output_file_tts_convert"])
382
- with gr.Row():
383
- unlock_full_method3.change(fn=unlock_f0, inputs=[unlock_full_method3], outputs=[method0])
384
- upload_f0_file0.upload(fn=lambda inp: shutil.move(inp.name, os.path.join("assets", "f0")), inputs=[upload_f0_file0], outputs=[f0_file_dropdown0])
385
- refesh_f0_file0.click(fn=change_f0_choices, inputs=[], outputs=[f0_file_dropdown0])
386
- with gr.Row():
387
- embed_mode1.change(fn=visible_embedders, inputs=[embed_mode1], outputs=[embedders0])
388
- autotune3.change(fn=visible, inputs=[autotune3], outputs=[f0_autotune_strength0])
389
- model_pth0.change(fn=get_index, inputs=[model_pth0], outputs=[model_index0])
390
- with gr.Row():
391
- cleaner1.change(fn=visible, inputs=[cleaner1], outputs=[clean_strength1])
392
- method0.change(fn=lambda method, hybrid: [visible(method == "hybrid"), hoplength_show(method, hybrid)], inputs=[method0, hybrid_method0], outputs=[hybrid_method0, hop_length0])
393
- hybrid_method0.change(fn=hoplength_show, inputs=[method0, hybrid_method0], outputs=[hop_length0])
394
- with gr.Row():
395
- refesh1.click(fn=change_models_choices, inputs=[], outputs=[model_pth0, model_index0])
396
- embedders0.change(fn=lambda embedders: visible(embedders == "custom"), inputs=[embedders0], outputs=[custom_embedders0])
397
- formant_shifting1.change(fn=lambda a: [visible(a)]*2, inputs=[formant_shifting1], outputs=[formant_qfrency1, formant_timbre1])
398
- with gr.Row():
399
- model_index0.change(fn=index_strength_show, inputs=[model_index0], outputs=[index_strength0])
400
- txt_input.upload(fn=process_input, inputs=[txt_input], outputs=[prompt])
401
- use_txt.change(fn=visible, inputs=[use_txt], outputs=[txt_input])
402
- with gr.Row():
403
- google_tts_check_box.change(fn=change_tts_voice_choices, inputs=[google_tts_check_box], outputs=[tts_voice])
404
- tts_button.click(
405
- fn=TTS,
406
- inputs=[
407
- prompt,
408
- tts_voice,
409
- speed,
410
- output_audio0,
411
- tts_pitch,
412
- google_tts_check_box,
413
- txt_input
414
- ],
415
- outputs=[tts_voice_audio],
416
- api_name="text-to-speech"
417
- )
418
- convert_button0.click(
419
- fn=convert_tts,
420
- inputs=[
421
- cleaner1,
422
- autotune3,
423
- pitch0,
424
- clean_strength1,
425
- model_pth0,
426
- model_index0,
427
- index_strength0,
428
- output_audio0,
429
- output_audio1,
430
- export_format0,
431
- method0,
432
- hybrid_method0,
433
- hop_length0,
434
- embedders0,
435
- custom_embedders0,
436
- resample_sr0,
437
- filter_radius0,
438
- volume_envelope0,
439
- protect0,
440
- split_audio0,
441
- f0_autotune_strength0,
442
- checkpointing0,
443
- onnx_f0_mode1,
444
- formant_shifting1,
445
- formant_qfrency1,
446
- formant_timbre1,
447
- f0_file_dropdown0,
448
- embed_mode1
449
- ],
450
- outputs=[tts_voice_convert],
451
- api_name="convert_tts"
452
- )
453
-
454
- # Whisper Conversion Tab
455
- with gr.TabItem(translations["convert_with_whisper"], visible=configs.get("convert_with_whisper", True)):
456
- gr.Markdown(f"## {translations['convert_with_whisper']}")
457
- with gr.Row():
458
- gr.Markdown(translations["convert_with_whisper_info"])
459
- with gr.Row():
460
- with gr.Column():
461
- with gr.Accordion(translations["model_accordion"] + " 1", open=True):
462
- with gr.Row(equal_height=True):
463
- model_pth2 = gr.Dropdown(label=translations["model_name"], choices=model_name, value=model_name[0] if len(model_name) >= 1 else "", interactive=True, allow_custom_value=True)
464
- model_index2 = gr.Dropdown(label=translations["index_path"], choices=index_path, value=index_path[0] if len(index_path) >= 1 else "", interactive=True, allow_custom_value=True)
465
- refesh2 = gr.Button(translations["refesh"])
466
- with gr.Accordion(translations["model_accordion"] + " 2", open=True):
467
- with gr.Row(equal_height=True):
468
- model_pth3 = gr.Dropdown(label=translations["model_name"], choices=model_name, value=model_name[0] if len(model_name) >= 1 else "", interactive=True, allow_custom_value=True)
469
- model_index3 = gr.Dropdown(label=translations["index_path"], choices=index_path, value=index_path[0] if len(index_path) >= 1 else "", interactive=True, allow_custom_value=True)
470
- refesh3 = gr.Button(translations["refesh"])
471
- with gr.Group():
472
- with gr.Row():
473
- cleaner2 = gr.Checkbox(label=translations["clear_audio"], value=False, interactive=True)
474
- autotune2 = gr.Checkbox(label=translations["autotune"], value=False, interactive=True)
475
- checkpointing2 = gr.Checkbox(label=translations["memory_efficient_training"], value=False, interactive=True)
476
- formant_shifting2 = gr.Checkbox(label=translations["formantshift"], value=False, interactive=True)
477
- with gr.Row():
478
- num_spk = gr.Slider(minimum=2, maximum=8, step=1, info=translations["num_spk_info"], label=translations["num_spk"], value=2, interactive=True)
479
- with gr.Row():
480
- with gr.Column():
481
- convert_button3 = gr.Button(translations["convert_audio"], variant="primary")
482
- with gr.Row():
483
- with gr.Column():
484
- with gr.Row():
485
- pitch3 = gr.Slider(minimum=-20, maximum=20, step=1, info=translations["pitch_info"], label=translations["pitch"], value=0, interactive=True)
486
- index_strength2 = gr.Slider(label=translations["index_strength"], info=translations["index_strength_info"], minimum=0, maximum=1, value=0.5, step=0.01, interactive=True, visible=model_index2.value != "")
487
- with gr.Accordion(translations["input_output"], open=False):
488
- with gr.Column():
489
- export_format2 = gr.Radio(label=translations["export_format"], info=translations["export_info"], choices=["wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"], value="wav", interactive=True)
490
- input_audio1 = gr.Dropdown(label=translations["audio_path"], value="", choices=paths_for_files, info=translations["provide_audio"], allow_custom_value=True, interactive=True)
491
- output_audio2 = gr.Textbox(label=translations["output_path"], value="audios/output.wav", placeholder="audios/output.wav", info=translations["output_path_info"], interactive=True)
492
- with gr.Column():
493
- refesh4 = gr.Button(translations["refesh"])
494
- with gr.Row():
495
- input2 = gr.File(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"])
496
- with gr.Column():
497
- with gr.Row():
498
- pitch4 = gr.Slider(minimum=-20, maximum=20, step=1, info=translations["pitch_info"], label=translations["pitch"], value=0, interactive=True)
499
- index_strength3 = gr.Slider(label=translations["index_strength"], info=translations["index_strength_info"], minimum=0, maximum=1, value=0.5, step=0.01, interactive=True, visible=model_index3.value != "")
500
- with gr.Accordion(translations["setting"], open=False):
501
- with gr.Row():
502
- model_size = gr.Radio(label=translations["model_size"], info=translations["model_size_info"], choices=["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"], value="medium", interactive=True)
503
- with gr.Accordion(translations["f0_method"], open=False):
504
- with gr.Group():
505
- with gr.Row():
506
- onnx_f0_mode4 = gr.Checkbox(label=translations["f0_onnx_mode"], info=translations["f0_onnx_mode_info"], value=False, interactive=True)
507
- unlock_full_method2 = gr.Checkbox(label=translations["f0_unlock"], info=translations["f0_unlock_info"], value=False, interactive=True)
508
- method3 = gr.Radio(label=translations["f0_method"], info=translations["f0_method_info"], choices=method_f0+["hybrid"], value="rmvpe", interactive=True)
509
- hybrid_method3 = gr.Dropdown(label=translations["f0_method_hybrid"], info=translations["f0_method_hybrid_info"], choices=["hybrid[pm+dio]", "hybrid[pm+crepe-tiny]", "hybrid[pm+crepe]", "hybrid[pm+fcpe]", "hybrid[pm+rmvpe]", "hybrid[pm+harvest]", "hybrid[pm+yin]", "hybrid[dio+crepe-tiny]", "hybrid[dio+crepe]", "hybrid[dio+fcpe]", "hybrid[dio+rmvpe]", "hybrid[dio+harvest]", "hybrid[dio+yin]", "hybrid[crepe-tiny+crepe]", "hybrid[crepe-tiny+fcpe]", "hybrid[crepe-tiny+rmvpe]", "hybrid[crepe-tiny+harvest]", "hybrid[crepe+fcpe]", "hybrid[crepe+rmvpe]", "hybrid[crepe+harvest]", "hybrid[crepe+yin]", "hybrid[fcpe+rmvpe]", "hybrid[fcpe+harvest]", "hybrid[fcpe+yin]", "hybrid[rmvpe+harvest]", "hybrid[rmvpe+yin]", "hybrid[harvest+yin]"], value="hybrid[pm+dio]", interactive=True, allow_custom_value=True, visible=method3.value == "hybrid")
510
- hop_length3 = gr.Slider(label="Hop length", info=translations["hop_length_info"], minimum=1, maximum=512, value=128, step=1, interactive=True, visible=False)
511
- with gr.Accordion(translations["hubert_model"], open=False):
512
- embed_mode3 = gr.Radio(label=translations["embed_mode"], info=translations["embed_mode_info"], value="fairseq", choices=embedders_mode, interactive=True, visible=True)
513
- embedders3 = gr.Radio(label=translations["hubert_model"], info=translations["hubert_info"], choices=embedders_model, value="hubert_base", interactive=True)
514
- custom_embedders3 = gr.Textbox(label=translations["modelname"], info=translations["modelname_info"], value="", placeholder="hubert_base", interactive=True, visible=embedders3.value == "custom")
515
- with gr.Column():
516
- clean_strength3 = gr.Slider(label=translations["clean_strength"], info=translations["clean_strength_info"], minimum=0, maximum=1, value=0.5, step=0.1, interactive=True, visible=cleaner2.value)
517
- f0_autotune_strength3 = gr.Slider(minimum=0, maximum=1, label=translations["autotune_rate"], info=translations["autotune_rate_info"], value=1, step=0.1, interactive=True, visible=autotune.value)
518
- resample_sr3 = gr.Slider(minimum=0, maximum=96000, label=translations["resample"], info=translations["resample_info"], value=0, step=1, interactive=True)
519
- filter_radius3 = gr.Slider(minimum=0, maximum=7, label=translations["filter_radius"], info=translations["filter_radius_info"], value=3, step=1, interactive=True)
520
- volume_envelope3 = gr.Slider(minimum=0, maximum=1, label=translations["volume_envelope"], info=translations["volume_envelope_info"], value=1, step=0.1, interactive=True)
521
- protect3 = gr.Slider(minimum=0, maximum=1, label=translations["protect"], info=translations["protect_info"], value=0.5, step=0.01, interactive=True)
522
- with gr.Row():
523
- formant_qfrency3 = gr.Slider(value=1.0, label=translations["formant_qfrency"] + " 1", info=translations["formant_qfrency"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
524
- formant_timbre3 = gr.Slider(value=1.0, label=translations["formant_timbre"] + " 1", info=translations["formant_timbre"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
525
- with gr.Row():
526
- formant_qfrency4 = gr.Slider(value=1.0, label=translations["formant_qfrency"] + " 2", info=translations["formant_qfrency"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
527
- formant_timbre4 = gr.Slider(value=1.0, label=translations["formant_timbre"] + " 2", info=translations["formant_timbre"], minimum=0.0, maximum=16.0, step=0.1, interactive=True, visible=False)
528
- with gr.Row():
529
- gr.Markdown(translations["input_output"])
530
- with gr.Row():
531
- play_audio2 = gr.Audio(show_download_button=True, interactive=False, label=translations["input_audio"])
532
- play_audio3 = gr.Audio(show_download_button=True, interactive=False, label=translations["output_file_tts_convert"])
533
- with gr.Row():
534
- autotune2.change(fn=visible, inputs=[autotune2], outputs=[f0_autotune_strength3])
535
- cleaner2.change(fn=visible, inputs=[cleaner2], outputs=[clean_strength3])
536
- method3.change(fn=lambda method, hybrid: [visible(method == "hybrid"), hoplength_show(method, hybrid)], inputs=[method3, hybrid_method3], outputs=[hybrid_method3, hop_length3])
537
- with gr.Row():
538
- hybrid_method3.change(fn=hoplength_show, inputs=[method3, hybrid_method3], outputs=[hop_length3])
539
- refesh2.click(fn=change_models_choices, inputs=[], outputs=[model_pth2, model_index2])
540
- model_pth2.change(fn=get_index, inputs=[model_pth2], outputs=[model_index2])
541
- with gr.Row():
542
- refesh3.click(fn=change_models_choices, inputs=[], outputs=[model_pth3, model_index3])
543
- model_pth3.change(fn=get_index, inputs=[model_pth3], outputs=[model_index3])
544
- input2.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("audios")), inputs=[input2], outputs=[input_audio1])
545
- with gr.Row():
546
- input_audio1.change(fn=lambda audio: audio if os.path.isfile(audio) else None, inputs=[input_audio1], outputs=[play_audio2])
547
- formant_shifting2.change(fn=lambda a: [visible(a)]*4, inputs=[formant_shifting2], outputs=[formant_qfrency3, formant_timbre3, formant_qfrency4, formant_timbre4])
548
- embedders3.change(fn=lambda embedders: visible(embedders == "custom"), inputs=[embedders3], outputs=[custom_embedders3])
549
- with gr.Row():
550
- refesh4.click(fn=change_audios_choices, inputs=[input_audio1], outputs=[input_audio1])
551
- model_index2.change(fn=index_strength_show, inputs=[model_index2], outputs=[index_strength2])
552
- model_index3.change(fn=index_strength_show, inputs=[model_index3], outputs=[index_strength3])
553
- with gr.Row():
554
- unlock_full_method2.change(fn=unlock_f0, inputs=[unlock_full_method2], outputs=[method3])
555
- embed_mode3.change(fn=visible_embedders, inputs=[embed_mode3], outputs=[embedders3])
556
- convert_button3.click(
557
- fn=convert_with_whisper,
558
- inputs=[
559
- num_spk,
560
- model_size,
561
- cleaner2,
562
- clean_strength3,
563
- autotune2,
564
- f0_autotune_strength3,
565
- checkpointing2,
566
- model_pth2,
567
- model_pth3,
568
- model_index2,
569
- model_index3,
570
- pitch3,
571
- pitch4,
572
- index_strength2,
573
- index_strength3,
574
- export_format2,
575
- input_audio1,
576
- output_audio2,
577
- onnx_f0_mode4,
578
- method3,
579
- hybrid_method3,
580
- hop_length3,
581
- embed_mode3,
582
- embedders3,
583
- custom_embedders3,
584
- resample_sr3,
585
- filter_radius3,
586
- volume_envelope3,
587
- protect3,
588
- formant_shifting2,
589
- formant_qfrency3,
590
- formant_timbre3,
591
- formant_qfrency4,
592
- formant_timbre4,
593
- ],
594
- outputs=[play_audio3],
595
- api_name="convert_with_whisper"
596
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/app/tabs/models/model.py DELETED
@@ -1,465 +0,0 @@
1
- from main.tools import huggingface
2
- from main.configs.config import Config
3
- from main.app.based.utils import *
4
- import gradio as gr
5
-
6
-
7
- def model_tabs():
8
- with gr.Tabs():
9
- with gr.Tab(label=translations["downloads"], visible=configs.get("downloads_tab", True)):
10
- gr.Markdown(translations["download_markdown"])
11
- with gr.Row():
12
- gr.Markdown(translations["download_markdown_2"])
13
- with gr.Row():
14
- with gr.Accordion(translations["model_download"], open=True):
15
- with gr.Row():
16
- downloadmodel = gr.Radio(label=translations["model_download_select"], choices=[translations["download_url"], translations["download_from_csv"], translations["search_models"], translations["upload"]], interactive=True, value=translations["download_url"])
17
- with gr.Row():
18
- gr.Markdown("___")
19
- with gr.Column():
20
- with gr.Row():
21
- url_input = gr.Textbox(label=translations["model_url"], value="", placeholder="https://...", scale=6)
22
- download_model_name = gr.Textbox(label=translations["modelname"], value="", placeholder=translations["modelname"], scale=2)
23
- url_download = gr.Button(value=translations["downloads"], scale=2)
24
- with gr.Column():
25
- model_browser = gr.Dropdown(choices=models.keys(), label=translations["model_warehouse"], scale=8, allow_custom_value=True, visible=False)
26
- download_from_browser = gr.Button(value=translations["get_model"], scale=2, variant="primary", visible=False)
27
- with gr.Column():
28
- search_name = gr.Textbox(label=translations["name_to_search"], placeholder=translations["modelname"], interactive=True, scale=8, visible=False)
29
- search = gr.Button(translations["search_2"], scale=2, visible=False)
30
- search_dropdown = gr.Dropdown(label=translations["select_download_model"], value="", choices=[], allow_custom_value=True, interactive=False, visible=False)
31
- download = gr.Button(translations["downloads"], variant="primary", visible=False)
32
- with gr.Column():
33
- model_upload = gr.File(label=translations["drop_model"], file_types=[".pth", ".onnx", ".index", ".zip"], visible=False)
34
- with gr.Row():
35
- with gr.Accordion(translations["download_pretrained_2"], open=False):
36
- with gr.Row():
37
- pretrain_download_choices = gr.Radio(label=translations["model_download_select"], choices=[translations["download_url"], translations["list_model"], translations["upload"]], value=translations["download_url"], interactive=True)
38
- with gr.Row():
39
- gr.Markdown("___")
40
- with gr.Column():
41
- with gr.Row():
42
- pretrainD = gr.Textbox(label=translations["pretrained_url"].format(dg="D"), value="", info=translations["only_huggingface"], placeholder="https://...", interactive=True, scale=4)
43
- pretrainG = gr.Textbox(label=translations["pretrained_url"].format(dg="G"), value="", info=translations["only_huggingface"], placeholder="https://...", interactive=True, scale=4)
44
- download_pretrain_button = gr.Button(translations["downloads"], scale=2)
45
- with gr.Column():
46
- with gr.Row():
47
- pretrain_choices = gr.Dropdown(label=translations["select_pretrain"], info=translations["select_pretrain_info"], choices=list(fetch_pretrained_data().keys()), value="Titan_Medium", allow_custom_value=True, interactive=True, scale=6, visible=False)
48
- sample_rate_pretrain = gr.Dropdown(label=translations["pretrain_sr"], info=translations["pretrain_sr"], choices=["48k", "40k", "32k"], value="48k", interactive=True, visible=False)
49
- download_pretrain_choices_button = gr.Button(translations["downloads"], scale=2, variant="primary", visible=False)
50
- with gr.Row():
51
- pretrain_upload_g = gr.File(label=translations["drop_pretrain"].format(dg="G"), file_types=[".pth"], visible=False)
52
- pretrain_upload_d = gr.File(label=translations["drop_pretrain"].format(dg="D"), file_types=[".pth"], visible=False)
53
- with gr.Row():
54
- url_download.click(
55
- fn=download_model,
56
- inputs=[
57
- url_input,
58
- download_model_name
59
- ],
60
- outputs=[url_input],
61
- api_name="download_model"
62
- )
63
- download_from_browser.click(
64
- fn=lambda model: download_model(models[model], model),
65
- inputs=[model_browser],
66
- outputs=[model_browser],
67
- api_name="download_browser"
68
- )
69
- with gr.Row():
70
- downloadmodel.change(fn=change_download_choices, inputs=[downloadmodel], outputs=[url_input, download_model_name, url_download, model_browser, download_from_browser, search_name, search, search_dropdown, download, model_upload])
71
- search.click(fn=search_models, inputs=[search_name], outputs=[search_dropdown, download])
72
- model_upload.upload(fn=save_drop_model, inputs=[model_upload], outputs=[model_upload])
73
- download.click(
74
- fn=lambda model: download_model(model_options[model], model),
75
- inputs=[search_dropdown],
76
- outputs=[search_dropdown],
77
- api_name="search_models"
78
- )
79
- with gr.Row():
80
- pretrain_download_choices.change(fn=change_download_pretrained_choices, inputs=[pretrain_download_choices], outputs=[pretrainD, pretrainG, download_pretrain_button, pretrain_choices, sample_rate_pretrain, download_pretrain_choices_button, pretrain_upload_d, pretrain_upload_g])
81
- pretrain_choices.change(fn=update_sample_rate_dropdown, inputs=[pretrain_choices], outputs=[sample_rate_pretrain])
82
- with gr.Row():
83
- download_pretrain_button.click(
84
- fn=download_pretrained_model,
85
- inputs=[
86
- pretrain_download_choices,
87
- pretrainD,
88
- pretrainG
89
- ],
90
- outputs=[pretrainD],
91
- api_name="download_pretrain_link"
92
- )
93
- download_pretrain_choices_button.click(
94
- fn=download_pretrained_model,
95
- inputs=[
96
- pretrain_download_choices,
97
- pretrain_choices,
98
- sample_rate_pretrain
99
- ],
100
- outputs=[pretrain_choices],
101
- api_name="download_pretrain_choices"
102
- )
103
- pretrain_upload_g.upload(
104
- fn=lambda pretrain_upload_g: shutil.move(pretrain_upload_g.name, os.path.join("assets", "models", "pretrained_custom")),
105
- inputs=[pretrain_upload_g],
106
- outputs=[],
107
- api_name="upload_pretrain_g"
108
- )
109
- pretrain_upload_d.upload(
110
- fn=lambda pretrain_upload_d: shutil.move(pretrain_upload_d.name, os.path.join("assets", "models", "pretrained_custom")),
111
- inputs=[pretrain_upload_d],
112
- outputs=[],
113
- api_name="upload_pretrain_d"
114
- )
115
-
116
- with gr.Tab(label=translations["createdataset"], visible=configs.get("create_dataset_tab", True)):
117
- gr.Markdown(translations["create_dataset_markdown"])
118
- with gr.Row():
119
- gr.Markdown(translations["create_dataset_markdown_2"])
120
- with gr.Row():
121
- dataset_url = gr.Textbox(label=translations["url_audio"], info=translations["create_dataset_url"], value="", placeholder="https://www.youtube.com/...", interactive=True)
122
- output_dataset = gr.Textbox(label=translations["output_data"], info=translations["output_data_info"], value="dataset", placeholder="dataset", interactive=True)
123
- with gr.Row():
124
- with gr.Column():
125
- with gr.Group():
126
- with gr.Row():
127
- separator_reverb = gr.Checkbox(label=translations["dereveb_audio"], value=False, interactive=True)
128
- denoise_mdx = gr.Checkbox(label=translations["denoise"], value=False, interactive=True)
129
- with gr.Row():
130
- kim_vocal_version = gr.Radio(label=translations["model_ver"], info=translations["model_ver_info"], choices=["Version-1", "Version-2"], value="Version-2", interactive=True)
131
- kim_vocal_overlap = gr.Radio(label=translations["overlap"], info=translations["overlap_info"], choices=["0.25", "0.5", "0.75", "0.99"], value="0.25", interactive=True)
132
- with gr.Row():
133
- kim_vocal_hop_length = gr.Slider(label="Hop length", info=translations["hop_length_info"], minimum=1, maximum=8192, value=1024, step=1, interactive=True)
134
- kim_vocal_batch_size = gr.Slider(label=translations["batch_size"], info=translations["mdx_batch_size_info"], minimum=1, maximum=64, value=1, step=1, interactive=True)
135
- with gr.Row():
136
- kim_vocal_segments_size = gr.Slider(label=translations["segments_size"], info=translations["segments_size_info"], minimum=32, maximum=3072, value=256, step=32, interactive=True)
137
- with gr.Row():
138
- sample_rate0 = gr.Slider(minimum=8000, maximum=96000, step=1, value=44100, label=translations["sr"], info=translations["sr_info"], interactive=True)
139
- with gr.Column():
140
- create_button = gr.Button(translations["createdataset"], variant="primary", scale=2, min_width=4000)
141
- with gr.Group():
142
- with gr.Row():
143
- clean_audio = gr.Checkbox(label=translations["clear_audio"], value=False, interactive=True)
144
- skip = gr.Checkbox(label=translations["skip"], value=False, interactive=True)
145
- with gr.Row():
146
- dataset_clean_strength = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label=translations["clean_strength"], info=translations["clean_strength_info"], interactive=True, visible=clean_audio.value)
147
- with gr.Row():
148
- skip_start = gr.Textbox(label=translations["skip_start"], info=translations["skip_start_info"], value="", placeholder="0,...", interactive=True, visible=skip.value)
149
- skip_end = gr.Textbox(label=translations["skip_end"], info=translations["skip_end_info"], value="", placeholder="0,...", interactive=True, visible=skip.value)
150
- create_dataset_info = gr.Textbox(label=translations["create_dataset_info"], value="", interactive=False)
151
- with gr.Row():
152
- clean_audio.change(fn=visible, inputs=[clean_audio], outputs=[dataset_clean_strength])
153
- skip.change(fn=lambda a: [valueEmpty_visible1(a)]*2, inputs=[skip], outputs=[skip_start, skip_end])
154
- with gr.Row():
155
- create_button.click(
156
- fn=create_dataset,
157
- inputs=[
158
- dataset_url,
159
- output_dataset,
160
- clean_audio,
161
- dataset_clean_strength,
162
- separator_reverb,
163
- kim_vocal_version,
164
- kim_vocal_overlap,
165
- kim_vocal_segments_size,
166
- denoise_mdx,
167
- skip,
168
- skip_start,
169
- skip_end,
170
- kim_vocal_hop_length,
171
- kim_vocal_batch_size,
172
- sample_rate0
173
- ],
174
- outputs=[create_dataset_info],
175
- api_name="create_dataset"
176
- )
177
-
178
- with gr.Tab(label=translations["training_model"], visible=configs.get("training_tab", True)):
179
- gr.Markdown(f"## {translations['training_model']}")
180
- with gr.Row():
181
- gr.Markdown(translations["training_markdown"])
182
- with gr.Row():
183
- with gr.Column():
184
- with gr.Row():
185
- with gr.Column():
186
- training_name = gr.Textbox(label=translations["modelname"], info=translations["training_model_name"], value="", placeholder=translations["modelname"], interactive=True)
187
- training_sr = gr.Radio(label=translations["sample_rate"], info=translations["sample_rate_info"], choices=["32k", "40k", "48k"], value="48k", interactive=True)
188
- training_ver = gr.Radio(label=translations["training_version"], info=translations["training_version_info"], choices=["v1", "v2"], value="v2", interactive=True)
189
- with gr.Row():
190
- clean_dataset = gr.Checkbox(label=translations["clear_dataset"], value=False, interactive=True)
191
- preprocess_cut = gr.Checkbox(label=translations["split_audio"], value=True, interactive=True)
192
- process_effects = gr.Checkbox(label=translations["preprocess_effect"], value=False, interactive=True)
193
- checkpointing1 = gr.Checkbox(label=translations["memory_efficient_training"], value=False, interactive=True)
194
- training_f0 = gr.Checkbox(label=translations["training_pitch"], value=True, interactive=True)
195
- upload = gr.Checkbox(label=translations["upload_dataset"], value=False, interactive=True)
196
- with gr.Row():
197
- clean_dataset_strength = gr.Slider(label=translations["clean_strength"], info=translations["clean_strength_info"], minimum=0, maximum=1, value=0.7, step=0.1, interactive=True, visible=clean_dataset.value)
198
- with gr.Column():
199
- preprocess_button = gr.Button(translations["preprocess_button"], scale=2)
200
- upload_dataset = gr.Files(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"], visible=upload.value)
201
- preprocess_info = gr.Textbox(label=translations["preprocess_info"], value="", interactive=False)
202
- with gr.Column():
203
- with gr.Row():
204
- with gr.Column():
205
- with gr.Accordion(label=translations["f0_method"], open=False):
206
- with gr.Group():
207
- with gr.Row():
208
- onnx_f0_mode2 = gr.Checkbox(label=translations["f0_onnx_mode"], info=translations["f0_onnx_mode_info"], value=False, interactive=True)
209
- unlock_full_method4 = gr.Checkbox(label=translations["f0_unlock"], info=translations["f0_unlock_info"], value=False, interactive=True)
210
- extract_method = gr.Radio(label=translations["f0_method"], info=translations["f0_method_info"], choices=method_f0, value="rmvpe", interactive=True)
211
- extract_hop_length = gr.Slider(label="Hop length", info=translations["hop_length_info"], minimum=1, maximum=512, value=128, step=1, interactive=True, visible=False)
212
- with gr.Accordion(label=translations["hubert_model"], open=False):
213
- with gr.Group():
214
- embed_mode2 = gr.Radio(label=translations["embed_mode"], info=translations["embed_mode_info"], value="fairseq", choices=embedders_mode, interactive=True, visible=True)
215
- extract_embedders = gr.Radio(label=translations["hubert_model"], info=translations["hubert_info"], choices=embedders_model, value="hubert_base", interactive=True)
216
- with gr.Row():
217
- extract_embedders_custom = gr.Textbox(label=translations["modelname"], info=translations["modelname_info"], value="", placeholder="hubert_base", interactive=True, visible=extract_embedders.value == "custom")
218
- with gr.Column():
219
- extract_button = gr.Button(translations["extract_button"], scale=2)
220
- extract_info = gr.Textbox(label=translations["extract_info"], value="", interactive=False)
221
- with gr.Column():
222
- with gr.Row():
223
- with gr.Column():
224
- total_epochs = gr.Slider(label=translations["total_epoch"], info=translations["total_epoch_info"], minimum=1, maximum=10000, value=300, step=1, interactive=True)
225
- save_epochs = gr.Slider(label=translations["save_epoch"], info=translations["save_epoch_info"], minimum=1, maximum=10000, value=50, step=1, interactive=True)
226
- with gr.Column():
227
- with gr.Row():
228
- index_button = gr.Button(f"3. {translations['create_index']}", variant="primary", scale=2)
229
- training_button = gr.Button(f"4. {translations['training_model']}", variant="primary", scale=2)
230
- with gr.Row():
231
- with gr.Accordion(label=translations["setting"], open=False):
232
- with gr.Row():
233
- index_algorithm = gr.Radio(label=translations["index_algorithm"], info=translations["index_algorithm_info"], choices=["Auto", "Faiss", "KMeans"], value="Auto", interactive=True)
234
- with gr.Row():
235
- custom_dataset = gr.Checkbox(label=translations["custom_dataset"], info=translations["custom_dataset_info"], value=False, interactive=True)
236
- overtraining_detector = gr.Checkbox(label=translations["overtraining_detector"], info=translations["overtraining_detector_info"], value=False, interactive=True)
237
- clean_up = gr.Checkbox(label=translations["cleanup_training"], info=translations["cleanup_training_info"], value=False, interactive=True)
238
- cache_in_gpu = gr.Checkbox(label=translations["cache_in_gpu"], info=translations["cache_in_gpu_info"], value=False, interactive=True)
239
- with gr.Column():
240
- dataset_path = gr.Textbox(label=translations["dataset_folder"], value="dataset", interactive=True, visible=custom_dataset.value)
241
- with gr.Column():
242
- threshold = gr.Slider(minimum=1, maximum=100, value=50, step=1, label=translations["threshold"], interactive=True, visible=overtraining_detector.value)
243
- with gr.Accordion(translations["setting_cpu_gpu"], open=False):
244
- with gr.Column():
245
- gpu_number = gr.Textbox(label=translations["gpu_number"], value=str("-".join(map(str, range(torch.cuda.device_count()))) if torch.cuda.is_available() else "-"), info=translations["gpu_number_info"], interactive=True)
246
- gpu_info = gr.Textbox(label=translations["gpu_info"], value=get_gpu_info(), info=translations["gpu_info_2"], interactive=False)
247
- cpu_core = gr.Slider(label=translations["cpu_core"], info=translations["cpu_core_info"], minimum=0, maximum=cpu_count(), value=cpu_count(), step=1, interactive=True)
248
- train_batch_size = gr.Slider(label=translations["batch_size"], info=translations["batch_size_info"], minimum=1, maximum=64, value=8, step=1, interactive=True)
249
- with gr.Row():
250
- save_only_latest = gr.Checkbox(label=translations["save_only_latest"], info=translations["save_only_latest_info"], value=True, interactive=True)
251
- save_every_weights = gr.Checkbox(label=translations["save_every_weights"], info=translations["save_every_weights_info"], value=True, interactive=True)
252
- not_use_pretrain = gr.Checkbox(label=translations["not_use_pretrain_2"], info=translations["not_use_pretrain_info"], value=False, interactive=True)
253
- custom_pretrain = gr.Checkbox(label=translations["custom_pretrain"], info=translations["custom_pretrain_info"], value=False, interactive=True)
254
- with gr.Row():
255
- vocoders = gr.Radio(label=translations["vocoder"], info=translations["vocoder_info"], choices=["Default", "MRF-HiFi-GAN", "RefineGAN"], value="Default", interactive=True)
256
- with gr.Row():
257
- deterministic = gr.Checkbox(label=translations["deterministic"], info=translations["deterministic_info"], value=False, interactive=True)
258
- benchmark = gr.Checkbox(label=translations["benchmark"], info=translations["benchmark_info"], value=False, interactive=True)
259
- with gr.Row():
260
- model_author = gr.Textbox(label=translations["training_author"], info=translations["training_author_info"], value="", placeholder=translations["training_author"], interactive=True)
261
- with gr.Row():
262
- with gr.Column():
263
- with gr.Accordion(translations["custom_pretrain_info"], open=False, visible=custom_pretrain.value and not not_use_pretrain.value) as pretrain_setting:
264
- pretrained_D = gr.Dropdown(label=translations["pretrain_file"].format(dg="D"), choices=pretrainedD, value=pretrainedD[0] if len(pretrainedD) > 0 else '', interactive=True, allow_custom_value=True)
265
- pretrained_G = gr.Dropdown(label=translations["pretrain_file"].format(dg="G"), choices=pretrainedG, value=pretrainedG[0] if len(pretrainedG) > 0 else '', interactive=True, allow_custom_value=True)
266
- refesh_pretrain = gr.Button(translations["refesh"], scale=2)
267
- with gr.Row():
268
- training_info = gr.Textbox(label=translations["train_info"], value="", interactive=False)
269
- with gr.Row():
270
- with gr.Column():
271
- with gr.Accordion(translations["export_model"], open=False):
272
- with gr.Row():
273
- model_file= gr.Dropdown(label=translations["model_name"], choices=model_name, value=model_name[0] if len(model_name) >= 1 else "", interactive=True, allow_custom_value=True)
274
- index_file = gr.Dropdown(label=translations["index_path"], choices=index_path, value=index_path[0] if len(index_path) >= 1 else "", interactive=True, allow_custom_value=True)
275
- with gr.Row():
276
- refesh_file = gr.Button(f"1. {translations['refesh']}", scale=2)
277
- zip_model = gr.Button(translations["zip_model"], variant="primary", scale=2)
278
- with gr.Row():
279
- zip_output = gr.File(label=translations["output_zip"], file_types=[".zip"], interactive=False, visible=False)
280
- with gr.Row():
281
- vocoders.change(fn=pitch_guidance_lock, inputs=[vocoders], outputs=[training_f0])
282
- training_f0.change(fn=vocoders_lock, inputs=[training_f0, vocoders], outputs=[vocoders])
283
- unlock_full_method4.change(fn=unlock_f0, inputs=[unlock_full_method4], outputs=[extract_method])
284
- with gr.Row():
285
- refesh_file.click(fn=change_models_choices, inputs=[], outputs=[model_file, index_file])
286
- zip_model.click(fn=zip_file, inputs=[training_name, model_file, index_file], outputs=[zip_output])
287
- dataset_path.change(fn=lambda folder: os.makedirs(folder, exist_ok=True), inputs=[dataset_path], outputs=[])
288
- with gr.Row():
289
- upload.change(fn=visible, inputs=[upload], outputs=[upload_dataset])
290
- overtraining_detector.change(fn=visible, inputs=[overtraining_detector], outputs=[threshold])
291
- clean_dataset.change(fn=visible, inputs=[clean_dataset], outputs=[clean_dataset_strength])
292
- with gr.Row():
293
- custom_dataset.change(fn=lambda custom_dataset: [visible(custom_dataset), "dataset"],inputs=[custom_dataset], outputs=[dataset_path, dataset_path])
294
- training_ver.change(fn=unlock_vocoder, inputs=[training_ver, vocoders], outputs=[vocoders])
295
- vocoders.change(fn=unlock_ver, inputs=[training_ver, vocoders], outputs=[training_ver])
296
- upload_dataset.upload(
297
- fn=lambda files, folder: [shutil.move(f.name, os.path.join(folder, os.path.split(f.name)[1])) for f in files] if folder != "" else gr_warning(translations["dataset_folder1"]),
298
- inputs=[upload_dataset, dataset_path],
299
- outputs=[],
300
- api_name="upload_dataset"
301
- )
302
- with gr.Row():
303
- not_use_pretrain.change(fn=lambda a, b: visible(a and not b), inputs=[custom_pretrain, not_use_pretrain], outputs=[pretrain_setting])
304
- custom_pretrain.change(fn=lambda a, b: visible(a and not b), inputs=[custom_pretrain, not_use_pretrain], outputs=[pretrain_setting])
305
- refesh_pretrain.click(fn=change_pretrained_choices, inputs=[], outputs=[pretrained_D, pretrained_G])
306
- with gr.Row():
307
- preprocess_button.click(
308
- fn=preprocess,
309
- inputs=[
310
- training_name,
311
- training_sr,
312
- cpu_core,
313
- preprocess_cut,
314
- process_effects,
315
- dataset_path,
316
- clean_dataset,
317
- clean_dataset_strength
318
- ],
319
- outputs=[preprocess_info],
320
- api_name="preprocess"
321
- )
322
- with gr.Row():
323
- embed_mode2.change(fn=visible_embedders, inputs=[embed_mode2], outputs=[extract_embedders])
324
- extract_method.change(fn=hoplength_show, inputs=[extract_method], outputs=[extract_hop_length])
325
- extract_embedders.change(fn=lambda extract_embedders: visible(extract_embedders == "custom"), inputs=[extract_embedders], outputs=[extract_embedders_custom])
326
- with gr.Row():
327
- extract_button.click(
328
- fn=extract,
329
- inputs=[
330
- training_name,
331
- training_ver,
332
- extract_method,
333
- training_f0,
334
- extract_hop_length,
335
- cpu_core,
336
- gpu_number,
337
- training_sr,
338
- extract_embedders,
339
- extract_embedders_custom,
340
- onnx_f0_mode2,
341
- embed_mode2
342
- ],
343
- outputs=[extract_info],
344
- api_name="extract"
345
- )
346
- with gr.Row():
347
- index_button.click(
348
- fn=create_index,
349
- inputs=[
350
- training_name,
351
- training_ver,
352
- index_algorithm
353
- ],
354
- outputs=[training_info],
355
- api_name="create_index"
356
- )
357
- with gr.Row():
358
- training_button.click(
359
- fn=training,
360
- inputs=[
361
- training_name,
362
- training_ver,
363
- save_epochs,
364
- save_only_latest,
365
- save_every_weights,
366
- total_epochs,
367
- training_sr,
368
- train_batch_size,
369
- gpu_number,
370
- training_f0,
371
- not_use_pretrain,
372
- custom_pretrain,
373
- pretrained_G,
374
- pretrained_D,
375
- overtraining_detector,
376
- threshold,
377
- clean_up,
378
- cache_in_gpu,
379
- model_author,
380
- vocoders,
381
- checkpointing1,
382
- deterministic,
383
- benchmark
384
- ],
385
- outputs=[training_info],
386
- api_name="training_model"
387
- )
388
-
389
- with gr.Tab(label=translations["fushion"], visible=configs.get("fushion_tab", True)):
390
- gr.Markdown(translations["fushion_markdown"])
391
- with gr.Row():
392
- gr.Markdown(translations["fushion_markdown_2"])
393
- with gr.Row():
394
- name_to_save = gr.Textbox(label=translations["modelname"], placeholder="Model.pth", value="", max_lines=1, interactive=True)
395
- with gr.Row():
396
- fushion_button = gr.Button(translations["fushion"], variant="primary", scale=4)
397
- with gr.Column():
398
- with gr.Row():
399
- model_a = gr.File(label=f"{translations['model_name']} 1", file_types=[".pth", ".onnx"])
400
- model_b = gr.File(label=f"{translations['model_name']} 2", file_types=[".pth", ".onnx"])
401
- with gr.Row():
402
- model_path_a = gr.Textbox(label=f"{translations['model_path']} 1", value="", placeholder="assets/weights/Model_1.pth")
403
- model_path_b = gr.Textbox(label=f"{translations['model_path']} 2", value="", placeholder="assets/weights/Model_2.pth")
404
- with gr.Row():
405
- ratio = gr.Slider(minimum=0, maximum=1, label=translations["model_ratio"], info=translations["model_ratio_info"], value=0.5, interactive=True)
406
- with gr.Row():
407
- output_model = gr.File(label=translations["output_model_path"], file_types=[".pth", ".onnx"], interactive=False, visible=False)
408
- with gr.Row():
409
- model_a.upload(fn=lambda model: shutil.move(model.name, os.path.join("assets", "weights")), inputs=[model_a], outputs=[model_path_a])
410
- model_b.upload(fn=lambda model: shutil.move(model.name, os.path.join("assets", "weights")), inputs=[model_b], outputs=[model_path_b])
411
- with gr.Row():
412
- fushion_button.click(
413
- fn=fushion_model,
414
- inputs=[
415
- name_to_save,
416
- model_path_a,
417
- model_path_b,
418
- ratio
419
- ],
420
- outputs=[name_to_save, output_model],
421
- api_name="fushion_model"
422
- )
423
- fushion_button.click(fn=lambda: visible(True), inputs=[], outputs=[output_model])
424
-
425
- with gr.Tab(label=translations["read_model"], visible=configs.get("read_tab", True)):
426
- gr.Markdown(translations["read_model_markdown"])
427
- with gr.Row():
428
- gr.Markdown(translations["read_model_markdown_2"])
429
- with gr.Row():
430
- model = gr.File(label=translations["drop_model"], file_types=[".pth", ".onnx"])
431
- with gr.Row():
432
- read_button = gr.Button(translations["readmodel"], variant="primary", scale=2)
433
- with gr.Column():
434
- model_path = gr.Textbox(label=translations["model_path"], value="", placeholder="assets/weights/Model.pth", info=translations["model_path_info"], interactive=True)
435
- output_info = gr.Textbox(label=translations["modelinfo"], value="", interactive=False, scale=6)
436
- with gr.Row():
437
- model.upload(fn=lambda model: shutil.move(model.name, os.path.join("assets", "weights")), inputs=[model], outputs=[model_path])
438
- read_button.click(
439
- fn=model_info,
440
- inputs=[model_path],
441
- outputs=[output_info],
442
- api_name="read_model"
443
- )
444
-
445
- with gr.Tab(label=translations["convert_model"], visible=configs.get("onnx_tab", True)):
446
- gr.Markdown(translations["pytorch2onnx"])
447
- with gr.Row():
448
- gr.Markdown(translations["pytorch2onnx_markdown"])
449
- with gr.Row():
450
- model_pth_upload = gr.File(label=translations["drop_model"], file_types=[".pth"])
451
- with gr.Row():
452
- convert_onnx = gr.Button(translations["convert_model"], variant="primary", scale=2)
453
- with gr.Row():
454
- model_pth_path = gr.Textbox(label=translations["model_path"], value="", placeholder="assets/weights/Model.pth", info=translations["model_path_info"], interactive=True)
455
- with gr.Row():
456
- output_model2 = gr.File(label=translations["output_model_path"], file_types=[".pth", ".onnx"], interactive=False, visible=False)
457
- with gr.Row():
458
- model_pth_upload.upload(fn=lambda model_pth_upload: shutil.move(model_pth_upload.name, os.path.join("assets", "weights")), inputs=[model_pth_upload], outputs=[model_pth_path])
459
- convert_onnx.click(
460
- fn=onnx_export,
461
- inputs=[model_pth_path],
462
- outputs=[output_model2, output_info],
463
- api_name="model_onnx_export"
464
- )
465
- convert_onnx.click(fn=lambda: visible(True), inputs=[], outputs=[output_model2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/app/tabs/utils/utils.py DELETED
@@ -1,305 +0,0 @@
1
- import gradio as gr
2
- from main.tools import huggingface
3
- from main.configs.config import Config
4
- from main.app.based.utils import *
5
-
6
- def utils_tabs():
7
- with gr.TabItem("utils"):
8
- with gr.Tabs():
9
- with gr.TabItem(translations["audio_editing"], visible=False):
10
- gr.Markdown(translations["audio_editing_info"])
11
- with gr.Row():
12
- gr.Markdown(translations["audio_editing_markdown"])
13
- with gr.Row():
14
- with gr.Column():
15
- with gr.Group():
16
- with gr.Row():
17
- save_compute = gr.Checkbox(label=translations["save_compute"], value=True, interactive=True)
18
- tar_prompt = gr.Textbox(label=translations["target_prompt"], info=translations["target_prompt_info"], placeholder="Piano and violin cover", lines=5, interactive=True)
19
- with gr.Column():
20
- cfg_scale_src = gr.Slider(value=3, minimum=0.5, maximum=25, label=translations["cfg_scale_src"], info=translations["cfg_scale_src_info"], interactive=True)
21
- cfg_scale_tar = gr.Slider(value=12, minimum=0.5, maximum=25, label=translations["cfg_scale_tar"], info=translations["cfg_scale_tar_info"], interactive=True)
22
- with gr.Row():
23
- edit_button = gr.Button(translations["editing"], variant="primary")
24
- with gr.Row():
25
- with gr.Column():
26
- drop_audio_file = gr.File(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"])
27
- display_audio = gr.Audio(show_download_button=True, interactive=False, label=translations["input_audio"])
28
- with gr.Column():
29
- with gr.Accordion(translations["input_output"], open=False):
30
- with gr.Column():
31
- export_audio_format = gr.Radio(label=translations["export_format"], info=translations["export_info"], choices=["wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"], value="wav", interactive=True)
32
- input_audiopath = gr.Dropdown(label=translations["audio_path"], value="", choices=paths_for_files, info=translations["provide_audio"], allow_custom_value=True, interactive=True)
33
- output_audiopath = gr.Textbox(label=translations["output_path"], value="audios/output.wav", placeholder="audios/output.wav", info=translations["output_path_info"], interactive=True)
34
- with gr.Column():
35
- refesh_audio = gr.Button(translations["refesh"])
36
- with gr.Accordion(translations["setting"], open=False):
37
- audioldm2_model = gr.Radio(label=translations["audioldm2_model"], info=translations["audioldm2_model_info"], choices=["audioldm2", "audioldm2-large", "audioldm2-music"], value="audioldm2-music", interactive=True)
38
- with gr.Row():
39
- src_prompt = gr.Textbox(label=translations["source_prompt"], lines=2, interactive=True, info=translations["source_prompt_info"], placeholder="A recording of a happy upbeat classical music piece")
40
- with gr.Row():
41
- with gr.Column():
42
- audioldm2_sample_rate = gr.Slider(minimum=8000, maximum=96000, label=translations["sr"], info=translations["sr_info"], value=44100, step=1, interactive=True)
43
- t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label=translations["t_start"], interactive=True, info=translations["t_start_info"])
44
- steps = gr.Slider(value=50, step=1, minimum=10, maximum=300, label=translations["steps_label"], info=translations["steps_info"], interactive=True)
45
- with gr.Row():
46
- gr.Markdown(translations["output_audio"])
47
- with gr.Row():
48
- output_audioldm2 = gr.Audio(show_download_button=True, interactive=False, label=translations["output_audio"])
49
- with gr.Row():
50
- refesh_audio.click(fn=change_audios_choices, inputs=[input_audiopath], outputs=[input_audiopath])
51
- drop_audio_file.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("audios")), inputs=[drop_audio_file], outputs=[input_audiopath])
52
- input_audiopath.change(fn=lambda audio: audio if os.path.isfile(audio) else None, inputs=[input_audiopath], outputs=[display_audio])
53
- with gr.Row():
54
- edit_button.click(
55
- fn=run_audioldm2,
56
- inputs=[
57
- input_audiopath,
58
- output_audiopath,
59
- export_audio_format,
60
- audioldm2_sample_rate,
61
- audioldm2_model,
62
- src_prompt,
63
- tar_prompt,
64
- steps,
65
- cfg_scale_src,
66
- cfg_scale_tar,
67
- t_start,
68
- save_compute
69
- ],
70
- outputs=[output_audioldm2],
71
- api_name="audioldm2"
72
- )
73
-
74
- with gr.TabItem(translations["audio_effects"], visible=configs.get("effects_tab", True)):
75
- gr.Markdown(translations["apply_audio_effects"])
76
- with gr.Row():
77
- gr.Markdown(translations["audio_effects_edit"])
78
- with gr.Row():
79
- with gr.Column():
80
- with gr.Row():
81
- reverb_check_box = gr.Checkbox(label=translations["reverb"], value=False, interactive=True)
82
- chorus_check_box = gr.Checkbox(label=translations["chorus"], value=False, interactive=True)
83
- delay_check_box = gr.Checkbox(label=translations["delay"], value=False, interactive=True)
84
- phaser_check_box = gr.Checkbox(label=translations["phaser"], value=False, interactive=True)
85
- compressor_check_box = gr.Checkbox(label=translations["compressor"], value=False, interactive=True)
86
- more_options = gr.Checkbox(label=translations["more_option"], value=False, interactive=True)
87
- with gr.Row():
88
- with gr.Accordion(translations["input_output"], open=False):
89
- with gr.Row():
90
- upload_audio = gr.File(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"])
91
- with gr.Row():
92
- audio_in_path = gr.Dropdown(label=translations["input_audio"], value="", choices=paths_for_files, info=translations["provide_audio"], interactive=True, allow_custom_value=True)
93
- audio_out_path = gr.Textbox(label=translations["output_audio"], value="audios/audio_effects.wav", placeholder="audios/audio_effects.wav", info=translations["provide_output"], interactive=True)
94
- with gr.Row():
95
- with gr.Column():
96
- audio_combination = gr.Checkbox(label=translations["merge_instruments"], value=False, interactive=True)
97
- audio_combination_input = gr.Dropdown(label=translations["input_audio"], value="", choices=paths_for_files, info=translations["provide_audio"], interactive=True, allow_custom_value=True, visible=audio_combination.value)
98
- with gr.Row():
99
- audio_effects_refesh = gr.Button(translations["refesh"])
100
- with gr.Row():
101
- audio_output_format = gr.Radio(label=translations["export_format"], info=translations["export_info"], choices=["wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"], value="wav", interactive=True)
102
- with gr.Row():
103
- apply_effects_button = gr.Button(translations["apply"], variant="primary", scale=2)
104
- with gr.Row():
105
- with gr.Column():
106
- with gr.Row():
107
- with gr.Accordion(translations["reverb"], open=False, visible=reverb_check_box.value) as reverb_accordion:
108
- reverb_freeze_mode = gr.Checkbox(label=translations["reverb_freeze"], info=translations["reverb_freeze_info"], value=False, interactive=True)
109
- reverb_room_size = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.15, label=translations["room_size"], info=translations["room_size_info"], interactive=True)
110
- reverb_damping = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label=translations["damping"], info=translations["damping_info"], interactive=True)
111
- reverb_wet_level = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.2, label=translations["wet_level"], info=translations["wet_level_info"], interactive=True)
112
- reverb_dry_level = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label=translations["dry_level"], info=translations["dry_level_info"], interactive=True)
113
- reverb_width = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label=translations["width"], info=translations["width_info"], interactive=True)
114
- with gr.Row():
115
- with gr.Accordion(translations["chorus"], open=False, visible=chorus_check_box.value) as chorus_accordion:
116
- chorus_depth = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label=translations["chorus_depth"], info=translations["chorus_depth_info"], interactive=True)
117
- chorus_rate_hz = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1.5, label=translations["chorus_rate_hz"], info=translations["chorus_rate_hz_info"], interactive=True)
118
- chorus_mix = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label=translations["chorus_mix"], info=translations["chorus_mix_info"], interactive=True)
119
- chorus_centre_delay_ms = gr.Slider(minimum=0, maximum=50, step=1, value=10, label=translations["chorus_centre_delay_ms"], info=translations["chorus_centre_delay_ms_info"], interactive=True)
120
- chorus_feedback = gr.Slider(minimum=-1, maximum=1, step=0.01, value=0, label=translations["chorus_feedback"], info=translations["chorus_feedback_info"], interactive=True)
121
- with gr.Row():
122
- with gr.Accordion(translations["delay"], open=False, visible=delay_check_box.value) as delay_accordion:
123
- delay_second = gr.Slider(minimum=0, maximum=5, step=0.01, value=0.5, label=translations["delay_seconds"], info=translations["delay_seconds_info"], interactive=True)
124
- delay_feedback = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label=translations["delay_feedback"], info=translations["delay_feedback_info"], interactive=True)
125
- delay_mix = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label=translations["delay_mix"], info=translations["delay_mix_info"], interactive=True)
126
- with gr.Column():
127
- with gr.Row():
128
- with gr.Accordion(translations["more_option"], open=False, visible=more_options.value) as more_accordion:
129
- with gr.Row():
130
- fade = gr.Checkbox(label=translations["fade"], value=False, interactive=True)
131
- bass_or_treble = gr.Checkbox(label=translations["bass_or_treble"], value=False, interactive=True)
132
- limiter = gr.Checkbox(label=translations["limiter"], value=False, interactive=True)
133
- resample_checkbox = gr.Checkbox(label=translations["resample"], value=False, interactive=True)
134
- with gr.Row():
135
- distortion_checkbox = gr.Checkbox(label=translations["distortion"], value=False, interactive=True)
136
- gain_checkbox = gr.Checkbox(label=translations["gain"], value=False, interactive=True)
137
- bitcrush_checkbox = gr.Checkbox(label=translations["bitcrush"], value=False, interactive=True)
138
- clipping_checkbox = gr.Checkbox(label=translations["clipping"], value=False, interactive=True)
139
- with gr.Accordion(translations["fade"], open=True, visible=fade.value) as fade_accordion:
140
- with gr.Row():
141
- fade_in = gr.Slider(minimum=0, maximum=10000, step=100, value=0, label=translations["fade_in"], info=translations["fade_in_info"], interactive=True)
142
- fade_out = gr.Slider(minimum=0, maximum=10000, step=100, value=0, label=translations["fade_out"], info=translations["fade_out_info"], interactive=True)
143
- with gr.Accordion(translations["bass_or_treble"], open=True, visible=bass_or_treble.value) as bass_treble_accordion:
144
- with gr.Row():
145
- bass_boost = gr.Slider(minimum=0, maximum=20, step=1, value=0, label=translations["bass_boost"], info=translations["bass_boost_info"], interactive=True)
146
- bass_frequency = gr.Slider(minimum=20, maximum=200, step=10, value=100, label=translations["bass_frequency"], info=translations["bass_frequency_info"], interactive=True)
147
- with gr.Row():
148
- treble_boost = gr.Slider(minimum=0, maximum=20, step=1, value=0, label=translations["treble_boost"], info=translations["treble_boost_info"], interactive=True)
149
- treble_frequency = gr.Slider(minimum=1000, maximum=10000, step=500, value=3000, label=translations["treble_frequency"], info=translations["treble_frequency_info"], interactive=True)
150
- with gr.Accordion(translations["limiter"], open=True, visible=limiter.value) as limiter_accordion:
151
- with gr.Row():
152
- limiter_threashold_db = gr.Slider(minimum=-60, maximum=0, step=1, value=-1, label=translations["limiter_threashold_db"], info=translations["limiter_threashold_db_info"], interactive=True)
153
- limiter_release_ms = gr.Slider(minimum=10, maximum=1000, step=1, value=100, label=translations["limiter_release_ms"], info=translations["limiter_release_ms_info"], interactive=True)
154
- with gr.Column():
155
- pitch_shift_semitones = gr.Slider(minimum=-20, maximum=20, step=1, value=0, label=translations["pitch"], info=translations["pitch_info"], interactive=True)
156
- audio_effect_resample_sr = gr.Slider(minimum=0, maximum=96000, step=1, value=0, label=translations["resample"], info=translations["resample_info"], interactive=True, visible=resample_checkbox.value)
157
- distortion_drive_db = gr.Slider(minimum=0, maximum=50, step=1, value=20, label=translations["distortion"], info=translations["distortion_info"], interactive=True, visible=distortion_checkbox.value)
158
- gain_db = gr.Slider(minimum=-60, maximum=60, step=1, value=0, label=translations["gain"], info=translations["gain_info"], interactive=True, visible=gain_checkbox.value)
159
- clipping_threashold_db = gr.Slider(minimum=-60, maximum=0, step=1, value=-1, label=translations["clipping_threashold_db"], info=translations["clipping_threashold_db_info"], interactive=True, visible=clipping_checkbox.value)
160
- bitcrush_bit_depth = gr.Slider(minimum=1, maximum=24, step=1, value=16, label=translations["bitcrush_bit_depth"], info=translations["bitcrush_bit_depth_info"], interactive=True, visible=bitcrush_checkbox.value)
161
- with gr.Row():
162
- with gr.Accordion(translations["phaser"], open=False, visible=phaser_check_box.value) as phaser_accordion:
163
- phaser_depth = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label=translations["phaser_depth"], info=translations["phaser_depth_info"], interactive=True)
164
- phaser_rate_hz = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1, label=translations["phaser_rate_hz"], info=translations["phaser_rate_hz_info"], interactive=True)
165
- phaser_mix = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label=translations["phaser_mix"], info=translations["phaser_mix_info"], interactive=True)
166
- phaser_centre_frequency_hz = gr.Slider(minimum=50, maximum=5000, step=10, value=1000, label=translations["phaser_centre_frequency_hz"], info=translations["phaser_centre_frequency_hz_info"], interactive=True)
167
- phaser_feedback = gr.Slider(minimum=-1, maximum=1, step=0.01, value=0, label=translations["phaser_feedback"], info=translations["phaser_feedback_info"], interactive=True)
168
- with gr.Row():
169
- with gr.Accordion(translations["compressor"], open=False, visible=compressor_check_box.value) as compressor_accordion:
170
- compressor_threashold_db = gr.Slider(minimum=-60, maximum=0, step=1, value=-20, label=translations["compressor_threashold_db"], info=translations["compressor_threashold_db_info"], interactive=True)
171
- compressor_ratio = gr.Slider(minimum=1, maximum=20, step=0.1, value=1, label=translations["compressor_ratio"], info=translations["compressor_ratio_info"], interactive=True)
172
- compressor_attack_ms = gr.Slider(minimum=0.1, maximum=100, step=0.1, value=10, label=translations["compressor_attack_ms"], info=translations["compressor_attack_ms_info"], interactive=True)
173
- compressor_release_ms = gr.Slider(minimum=10, maximum=1000, step=1, value=100, label=translations["compressor_release_ms"], info=translations["compressor_release_ms_info"], interactive=True)
174
- with gr.Row():
175
- gr.Markdown(translations["output_audio"])
176
- with gr.Row():
177
- audio_play_input = gr.Audio(show_download_button=True, interactive=False, label=translations["input_audio"])
178
- audio_play_output = gr.Audio(show_download_button=True, interactive=False, label=translations["output_audio"])
179
- with gr.Row():
180
- reverb_check_box.change(fn=visible, inputs=[reverb_check_box], outputs=[reverb_accordion])
181
- chorus_check_box.change(fn=visible, inputs=[chorus_check_box], outputs=[chorus_accordion])
182
- delay_check_box.change(fn=visible, inputs=[delay_check_box], outputs=[delay_accordion])
183
- with gr.Row():
184
- compressor_check_box.change(fn=visible, inputs=[compressor_check_box], outputs=[compressor_accordion])
185
- phaser_check_box.change(fn=visible, inputs=[phaser_check_box], outputs=[phaser_accordion])
186
- more_options.change(fn=visible, inputs=[more_options], outputs=[more_accordion])
187
- with gr.Row():
188
- fade.change(fn=visible, inputs=[fade], outputs=[fade_accordion])
189
- bass_or_treble.change(fn=visible, inputs=[bass_or_treble], outputs=[bass_treble_accordion])
190
- limiter.change(fn=visible, inputs=[limiter], outputs=[limiter_accordion])
191
- resample_checkbox.change(fn=visible, inputs=[resample_checkbox], outputs=[audio_effect_resample_sr])
192
- with gr.Row():
193
- distortion_checkbox.change(fn=visible, inputs=[distortion_checkbox], outputs=[distortion_drive_db])
194
- gain_checkbox.change(fn=visible, inputs=[gain_checkbox], outputs=[gain_db])
195
- clipping_checkbox.change(fn=visible, inputs=[clipping_checkbox], outputs=[clipping_threashold_db])
196
- bitcrush_checkbox.change(fn=visible, inputs=[bitcrush_checkbox], outputs=[bitcrush_bit_depth])
197
- with gr.Row():
198
- upload_audio.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("audios")), inputs=[upload_audio], outputs=[audio_in_path])
199
- audio_in_path.change(fn=lambda audio: audio if audio else None, inputs=[audio_in_path], outputs=[audio_play_input])
200
- audio_effects_refesh.click(fn=lambda a, b: [change_audios_choices(a), change_audios_choices(b)], inputs=[audio_in_path, audio_combination_input], outputs=[audio_in_path, audio_combination_input])
201
- with gr.Row():
202
- more_options.change(fn=lambda: [False]*8, inputs=[], outputs=[fade, bass_or_treble, limiter, resample_checkbox, distortion_checkbox, gain_checkbox, clipping_checkbox, bitcrush_checkbox])
203
- audio_combination.change(fn=visible, inputs=[audio_combination], outputs=[audio_combination_input])
204
- with gr.Row():
205
- apply_effects_button.click(
206
- fn=audio_effects,
207
- inputs=[
208
- audio_in_path,
209
- audio_out_path,
210
- resample_checkbox,
211
- audio_effect_resample_sr,
212
- chorus_depth,
213
- chorus_rate_hz,
214
- chorus_mix,
215
- chorus_centre_delay_ms,
216
- chorus_feedback,
217
- distortion_drive_db,
218
- reverb_room_size,
219
- reverb_damping,
220
- reverb_wet_level,
221
- reverb_dry_level,
222
- reverb_width,
223
- reverb_freeze_mode,
224
- pitch_shift_semitones,
225
- delay_second,
226
- delay_feedback,
227
- delay_mix,
228
- compressor_threashold_db,
229
- compressor_ratio,
230
- compressor_attack_ms,
231
- compressor_release_ms,
232
- limiter_threashold_db,
233
- limiter_release_ms,
234
- gain_db,
235
- bitcrush_bit_depth,
236
- clipping_threashold_db,
237
- phaser_rate_hz,
238
- phaser_depth,
239
- phaser_centre_frequency_hz,
240
- phaser_feedback,
241
- phaser_mix,
242
- bass_boost,
243
- bass_frequency,
244
- treble_boost,
245
- treble_frequency,
246
- fade_in,
247
- fade_out,
248
- audio_output_format,
249
- chorus_check_box,
250
- distortion_checkbox,
251
- reverb_check_box,
252
- delay_check_box,
253
- compressor_check_box,
254
- limiter,
255
- gain_checkbox,
256
- bitcrush_checkbox,
257
- clipping_checkbox,
258
- phaser_check_box,
259
- bass_or_treble,
260
- fade,
261
- audio_combination,
262
- audio_combination_input
263
- ],
264
- outputs=[audio_play_output],
265
- api_name="audio_effects"
266
- )
267
-
268
- with gr.TabItem(translations["f0_extractor_tab"], visible=configs.get("f0_extractor_tab", True)):
269
- gr.Markdown(translations["f0_extractor_markdown"])
270
- with gr.Row():
271
- gr.Markdown(translations["f0_extractor_markdown_2"])
272
- with gr.Row():
273
- extractor_button = gr.Button(translations["extract_button"].replace("2. ", ""), variant="primary")
274
- with gr.Row():
275
- with gr.Column():
276
- upload_audio_file = gr.File(label=translations["drop_audio"], file_types=[".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".mp4", ".aac", ".alac", ".wma", ".aiff", ".webm", ".ac3"])
277
- audioplay = gr.Audio(show_download_button=True, interactive=False, label=translations["input_audio"])
278
- with gr.Column():
279
- with gr.Accordion(translations["f0_method"], open=False):
280
- with gr.Group():
281
- onnx_f0_mode3 = gr.Checkbox(label=translations["f0_onnx_mode"], info=translations["f0_onnx_mode_info"], value=False, interactive=True)
282
- f0_method_extract = gr.Radio(label=translations["f0_method"], info=translations["f0_method_info"], choices=method_f0, value="rmvpe", interactive=True)
283
- with gr.Accordion(translations["audio_path"], open=True):
284
- input_audio_path = gr.Dropdown(label=translations["audio_path"], value="", choices=paths_for_files, allow_custom_value=True, interactive=True)
285
- refesh_audio_button = gr.Button(translations["refesh"])
286
- with gr.Row():
287
- gr.Markdown("___")
288
- with gr.Row():
289
- file_output = gr.File(label="", file_types=[".txt"], interactive=False)
290
- image_output = gr.Image(label="", interactive=False, show_download_button=True)
291
- with gr.Row():
292
- upload_audio_file.upload(fn=lambda audio_in: shutil.move(audio_in.name, os.path.join("audios")), inputs=[upload_audio_file], outputs=[input_audio_path])
293
- input_audio_path.change(fn=lambda audio: audio if os.path.isfile(audio) else None, inputs=[input_audio_path], outputs=[audioplay])
294
- refesh_audio_button.click(fn=change_audios_choices, inputs=[input_audio_path], outputs=[input_audio_path])
295
- with gr.Row():
296
- extractor_button.click(
297
- fn=f0_extract,
298
- inputs=[
299
- input_audio_path,
300
- f0_method_extract,
301
- onnx_f0_mode3
302
- ],
303
- outputs=[file_output, image_output],
304
- api_name="f0_extract"
305
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/app/tensorboard.py DELETED
@@ -1,30 +0,0 @@
1
- import os
2
- import sys
3
- import json
4
- import logging
5
- import webbrowser
6
-
7
- from tensorboard import program
8
-
9
- sys.path.append(os.getcwd())
10
-
11
- from main.configs.config import Config
12
- translations = Config().translations
13
-
14
- with open(os.path.join("main", "configs", "config.json"), "r") as f:
15
- configs = json.load(f)
16
-
17
- def launch_tensorboard():
18
- for l in ["root", "tensorboard"]:
19
- logging.getLogger(l).setLevel(logging.ERROR)
20
-
21
- tb = program.TensorBoard()
22
- tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs['tensorboard_port']}"])
23
- url = tb.launch()
24
-
25
- print(f"{translations['tensorboard_url']}: {url}")
26
- if "--open" in sys.argv: webbrowser.open(url)
27
-
28
- return f"{translations['tensorboard_url']}: {url}"
29
-
30
- if __name__ == "__main__": launch_tensorboard()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/config.json DELETED
@@ -1,549 +0,0 @@
1
- {
2
- "language": "en-US",
3
- "support_language": [
4
- "en-US",
5
- "id_Id",
6
- "ja-JP",
7
- "vi-VN"
8
- ],
9
- "theme": "gradio/default",
10
- "themes": [
11
- "NoCrypt/miku",
12
- "gstaff/xkcd",
13
- "JohnSmith9982/small_and_pretty",
14
- "ParityError/Interstellar",
15
- "earneleh/paris",
16
- "shivi/calm_seafoam",
17
- "Hev832/Applio",
18
- "YTheme/Minecraft",
19
- "gstaff/sketch",
20
- "SebastianBravo/simci_css",
21
- "allenai/gradio-theme",
22
- "Nymbo/Nymbo_Theme_5",
23
- "lone17/kotaemon",
24
- "Zarkel/IBM_Carbon_Theme",
25
- "SherlockRamos/Feliz",
26
- "freddyaboulton/dracula_revamped",
27
- "freddyaboulton/bad-theme-space",
28
- "gradio/dracula_revamped",
29
- "abidlabs/dracula_revamped",
30
- "gradio/dracula_test",
31
- "gradio/seafoam",
32
- "gradio/glass",
33
- "gradio/monochrome",
34
- "gradio/soft",
35
- "gradio/default",
36
- "gradio/base",
37
- "abidlabs/pakistan",
38
- "dawood/microsoft_windows",
39
- "ysharma/steampunk",
40
- "ysharma/huggingface",
41
- "abidlabs/Lime",
42
- "freddyaboulton/this-theme-does-not-exist-2",
43
- "aliabid94/new-theme",
44
- "aliabid94/test2",
45
- "aliabid94/test3",
46
- "aliabid94/test4",
47
- "abidlabs/banana",
48
- "freddyaboulton/test-blue",
49
- "gstaff/whiteboard",
50
- "ysharma/llamas",
51
- "abidlabs/font-test",
52
- "YenLai/Superhuman",
53
- "bethecloud/storj_theme",
54
- "sudeepshouche/minimalist",
55
- "knotdgaf/gradiotest",
56
- "ParityError/Anime",
57
- "Ajaxon6255/Emerald_Isle",
58
- "ParityError/LimeFace",
59
- "finlaymacklon/smooth_slate",
60
- "finlaymacklon/boxy_violet",
61
- "derekzen/stardust",
62
- "EveryPizza/Cartoony-Gradio-Theme",
63
- "Ifeanyi/Cyanister",
64
- "Tshackelton/IBMPlex-DenseReadable",
65
- "snehilsanyal/scikit-learn",
66
- "Himhimhim/xkcd",
67
- "nota-ai/theme",
68
- "rawrsor1/Everforest",
69
- "rottenlittlecreature/Moon_Goblin",
70
- "abidlabs/test-yellow",
71
- "abidlabs/test-yellow3",
72
- "idspicQstitho/dracula_revamped",
73
- "kfahn/AnimalPose",
74
- "HaleyCH/HaleyCH_Theme",
75
- "simulKitke/dracula_test",
76
- "braintacles/CrimsonNight",
77
- "wentaohe/whiteboardv2",
78
- "reilnuud/polite",
79
- "remilia/Ghostly",
80
- "Franklisi/darkmode",
81
- "coding-alt/soft",
82
- "xiaobaiyuan/theme_land",
83
- "step-3-profit/Midnight-Deep",
84
- "xiaobaiyuan/theme_demo",
85
- "Taithrah/Minimal",
86
- "Insuz/SimpleIndigo",
87
- "zkunn/Alipay_Gradio_theme",
88
- "Insuz/Mocha",
89
- "xiaobaiyuan/theme_brief",
90
- "Ama434/434-base-Barlow",
91
- "Ama434/def_barlow",
92
- "Ama434/neutral-barlow",
93
- "dawood/dracula_test",
94
- "nuttea/Softblue",
95
- "BlueDancer/Alien_Diffusion",
96
- "naughtondale/monochrome",
97
- "Dagfinn1962/standard",
98
- "default"
99
- ],
100
- "mdx_model": [
101
- "Main_340",
102
- "Main_390",
103
- "Main_406",
104
- "Main_427",
105
- "Main_438",
106
- "Inst_full_292",
107
- "Inst_HQ_1",
108
- "Inst_HQ_2",
109
- "Inst_HQ_3",
110
- "Inst_HQ_4",
111
- "Inst_HQ_5",
112
- "Kim_Vocal_1",
113
- "Kim_Vocal_2",
114
- "Kim_Inst",
115
- "Inst_187_beta",
116
- "Inst_82_beta",
117
- "Inst_90_beta",
118
- "Voc_FT",
119
- "Crowd_HQ",
120
- "Inst_1",
121
- "Inst_2",
122
- "Inst_3",
123
- "MDXNET_1_9703",
124
- "MDXNET_2_9682",
125
- "MDXNET_3_9662",
126
- "Inst_Main",
127
- "MDXNET_Main",
128
- "MDXNET_9482"
129
- ],
130
- "demucs_model": [
131
- "HT-Normal",
132
- "HT-Tuned",
133
- "HD_MMI",
134
- "HT_6S"
135
- ],
136
- "edge_tts": [
137
- "af-ZA-AdriNeural",
138
- "af-ZA-WillemNeural",
139
- "sq-AL-AnilaNeural",
140
- "sq-AL-IlirNeural",
141
- "am-ET-AmehaNeural",
142
- "am-ET-MekdesNeural",
143
- "ar-DZ-AminaNeural",
144
- "ar-DZ-IsmaelNeural",
145
- "ar-BH-AliNeural",
146
- "ar-BH-LailaNeural",
147
- "ar-EG-SalmaNeural",
148
- "ar-EG-ShakirNeural",
149
- "ar-IQ-BasselNeural",
150
- "ar-IQ-RanaNeural",
151
- "ar-JO-SanaNeural",
152
- "ar-JO-TaimNeural",
153
- "ar-KW-FahedNeural",
154
- "ar-KW-NouraNeural",
155
- "ar-LB-LaylaNeural",
156
- "ar-LB-RamiNeural",
157
- "ar-LY-ImanNeural",
158
- "ar-LY-OmarNeural",
159
- "ar-MA-JamalNeural",
160
- "ar-MA-MounaNeural",
161
- "ar-OM-AbdullahNeural",
162
- "ar-OM-AyshaNeural",
163
- "ar-QA-AmalNeural",
164
- "ar-QA-MoazNeural",
165
- "ar-SA-HamedNeural",
166
- "ar-SA-ZariyahNeural",
167
- "ar-SY-AmanyNeural",
168
- "ar-SY-LaithNeural",
169
- "ar-TN-HediNeural",
170
- "ar-TN-ReemNeural",
171
- "ar-AE-FatimaNeural",
172
- "ar-AE-HamdanNeural",
173
- "ar-YE-MaryamNeural",
174
- "ar-YE-SalehNeural",
175
- "az-AZ-BabekNeural",
176
- "az-AZ-BanuNeural",
177
- "bn-BD-NabanitaNeural",
178
- "bn-BD-PradeepNeural",
179
- "bn-IN-BashkarNeural",
180
- "bn-IN-TanishaaNeural",
181
- "bs-BA-GoranNeural",
182
- "bs-BA-VesnaNeural",
183
- "bg-BG-BorislavNeural",
184
- "bg-BG-KalinaNeural",
185
- "my-MM-NilarNeural",
186
- "my-MM-ThihaNeural",
187
- "ca-ES-EnricNeural",
188
- "ca-ES-JoanaNeural",
189
- "zh-HK-HiuGaaiNeural",
190
- "zh-HK-HiuMaanNeural",
191
- "zh-HK-WanLungNeural",
192
- "zh-CN-XiaoxiaoNeural",
193
- "zh-CN-XiaoyiNeural",
194
- "zh-CN-YunjianNeural",
195
- "zh-CN-YunxiNeural",
196
- "zh-CN-YunxiaNeural",
197
- "zh-CN-YunyangNeural",
198
- "zh-CN-liaoning-XiaobeiNeural",
199
- "zh-TW-HsiaoChenNeural",
200
- "zh-TW-YunJheNeural",
201
- "zh-TW-HsiaoYuNeural",
202
- "zh-CN-shaanxi-XiaoniNeural",
203
- "hr-HR-GabrijelaNeural",
204
- "hr-HR-SreckoNeural",
205
- "cs-CZ-AntoninNeural",
206
- "cs-CZ-VlastaNeural",
207
- "da-DK-ChristelNeural",
208
- "da-DK-JeppeNeural",
209
- "nl-BE-ArnaudNeural",
210
- "nl-BE-DenaNeural",
211
- "nl-NL-ColetteNeural",
212
- "nl-NL-FennaNeural",
213
- "nl-NL-MaartenNeural",
214
- "en-AU-NatashaNeural",
215
- "en-AU-WilliamNeural",
216
- "en-CA-ClaraNeural",
217
- "en-CA-LiamNeural",
218
- "en-HK-SamNeural",
219
- "en-HK-YanNeural",
220
- "en-IN-NeerjaExpressiveNeural",
221
- "en-IN-NeerjaNeural",
222
- "en-IN-PrabhatNeural",
223
- "en-IE-ConnorNeural",
224
- "en-IE-EmilyNeural",
225
- "en-KE-AsiliaNeural",
226
- "en-KE-ChilembaNeural",
227
- "en-NZ-MitchellNeural",
228
- "en-NZ-MollyNeural",
229
- "en-NG-AbeoNeural",
230
- "en-NG-EzinneNeural",
231
- "en-PH-JamesNeural",
232
- "en-PH-RosaNeural",
233
- "en-SG-LunaNeural",
234
- "en-SG-WayneNeural",
235
- "en-ZA-LeahNeural",
236
- "en-ZA-LukeNeural",
237
- "en-TZ-ElimuNeural",
238
- "en-TZ-ImaniNeural",
239
- "en-GB-LibbyNeural",
240
- "en-GB-MaisieNeural",
241
- "en-GB-RyanNeural",
242
- "en-GB-SoniaNeural",
243
- "en-GB-ThomasNeural",
244
- "en-US-AvaMultilingualNeural",
245
- "en-US-AndrewMultilingualNeural",
246
- "en-US-EmmaMultilingualNeural",
247
- "en-US-BrianMultilingualNeural",
248
- "en-US-AvaNeural",
249
- "en-US-AndrewNeural",
250
- "en-US-EmmaNeural",
251
- "en-US-BrianNeural",
252
- "en-US-AnaNeural",
253
- "en-US-AriaNeural",
254
- "en-US-ChristopherNeural",
255
- "en-US-EricNeural",
256
- "en-US-GuyNeural",
257
- "en-US-JennyNeural",
258
- "en-US-MichelleNeural",
259
- "en-US-RogerNeural",
260
- "en-US-SteffanNeural",
261
- "et-EE-AnuNeural",
262
- "et-EE-KertNeural",
263
- "fil-PH-AngeloNeural",
264
- "fil-PH-BlessicaNeural",
265
- "fi-FI-HarriNeural",
266
- "fi-FI-NooraNeural",
267
- "fr-BE-CharlineNeural",
268
- "fr-BE-GerardNeural",
269
- "fr-CA-ThierryNeural",
270
- "fr-CA-AntoineNeural",
271
- "fr-CA-JeanNeural",
272
- "fr-CA-SylvieNeural",
273
- "fr-FR-VivienneMultilingualNeural",
274
- "fr-FR-RemyMultilingualNeural",
275
- "fr-FR-DeniseNeural",
276
- "fr-FR-EloiseNeural",
277
- "fr-FR-HenriNeural",
278
- "fr-CH-ArianeNeural",
279
- "fr-CH-FabriceNeural",
280
- "gl-ES-RoiNeural",
281
- "gl-ES-SabelaNeural",
282
- "ka-GE-EkaNeural",
283
- "ka-GE-GiorgiNeural",
284
- "de-AT-IngridNeural",
285
- "de-AT-JonasNeural",
286
- "de-DE-SeraphinaMultilingualNeural",
287
- "de-DE-FlorianMultilingualNeural",
288
- "de-DE-AmalaNeural",
289
- "de-DE-ConradNeural",
290
- "de-DE-KatjaNeural",
291
- "de-DE-KillianNeural",
292
- "de-CH-JanNeural",
293
- "de-CH-LeniNeural",
294
- "el-GR-AthinaNeural",
295
- "el-GR-NestorasNeural",
296
- "gu-IN-DhwaniNeural",
297
- "gu-IN-NiranjanNeural",
298
- "he-IL-AvriNeural",
299
- "he-IL-HilaNeural",
300
- "hi-IN-MadhurNeural",
301
- "hi-IN-SwaraNeural",
302
- "hu-HU-NoemiNeural",
303
- "hu-HU-TamasNeural",
304
- "is-IS-GudrunNeural",
305
- "is-IS-GunnarNeural",
306
- "id-ID-ArdiNeural",
307
- "id-ID-GadisNeural",
308
- "ga-IE-ColmNeural",
309
- "ga-IE-OrlaNeural",
310
- "it-IT-GiuseppeNeural",
311
- "it-IT-DiegoNeural",
312
- "it-IT-ElsaNeural",
313
- "it-IT-IsabellaNeural",
314
- "ja-JP-KeitaNeural",
315
- "ja-JP-NanamiNeural",
316
- "jv-ID-DimasNeural",
317
- "jv-ID-SitiNeural",
318
- "kn-IN-GaganNeural",
319
- "kn-IN-SapnaNeural",
320
- "kk-KZ-AigulNeural",
321
- "kk-KZ-DauletNeural",
322
- "km-KH-PisethNeural",
323
- "km-KH-SreymomNeural",
324
- "ko-KR-HyunsuNeural",
325
- "ko-KR-InJoonNeural",
326
- "ko-KR-SunHiNeural",
327
- "lo-LA-ChanthavongNeural",
328
- "lo-LA-KeomanyNeural",
329
- "lv-LV-EveritaNeural",
330
- "lv-LV-NilsNeural",
331
- "lt-LT-LeonasNeural",
332
- "lt-LT-OnaNeural",
333
- "mk-MK-AleksandarNeural",
334
- "mk-MK-MarijaNeural",
335
- "ms-MY-OsmanNeural",
336
- "ms-MY-YasminNeural",
337
- "ml-IN-MidhunNeural",
338
- "ml-IN-SobhanaNeural",
339
- "mt-MT-GraceNeural",
340
- "mt-MT-JosephNeural",
341
- "mr-IN-AarohiNeural",
342
- "mr-IN-ManoharNeural",
343
- "mn-MN-BataaNeural",
344
- "mn-MN-YesuiNeural",
345
- "ne-NP-HemkalaNeural",
346
- "ne-NP-SagarNeural",
347
- "nb-NO-FinnNeural",
348
- "nb-NO-PernilleNeural",
349
- "ps-AF-GulNawazNeural",
350
- "ps-AF-LatifaNeural",
351
- "fa-IR-DilaraNeural",
352
- "fa-IR-FaridNeural",
353
- "pl-PL-MarekNeural",
354
- "pl-PL-ZofiaNeural",
355
- "pt-BR-ThalitaNeural",
356
- "pt-BR-AntonioNeural",
357
- "pt-BR-FranciscaNeural",
358
- "pt-PT-DuarteNeural",
359
- "pt-PT-RaquelNeural",
360
- "ro-RO-AlinaNeural",
361
- "ro-RO-EmilNeural",
362
- "ru-RU-DmitryNeural",
363
- "ru-RU-SvetlanaNeural",
364
- "sr-RS-NicholasNeural",
365
- "sr-RS-SophieNeural",
366
- "si-LK-SameeraNeural",
367
- "si-LK-ThiliniNeural",
368
- "sk-SK-LukasNeural",
369
- "sk-SK-ViktoriaNeural",
370
- "sl-SI-PetraNeural",
371
- "sl-SI-RokNeural",
372
- "so-SO-MuuseNeural",
373
- "so-SO-UbaxNeural",
374
- "es-AR-ElenaNeural",
375
- "es-AR-TomasNeural",
376
- "es-BO-MarceloNeural",
377
- "es-BO-SofiaNeural",
378
- "es-CL-CatalinaNeural",
379
- "es-CL-LorenzoNeural",
380
- "es-ES-XimenaNeural",
381
- "es-CO-GonzaloNeural",
382
- "es-CO-SalomeNeural",
383
- "es-CR-JuanNeural",
384
- "es-CR-MariaNeural",
385
- "es-CU-BelkysNeural",
386
- "es-CU-ManuelNeural",
387
- "es-DO-EmilioNeural",
388
- "es-DO-RamonaNeural",
389
- "es-EC-AndreaNeural",
390
- "es-EC-LuisNeural",
391
- "es-SV-LorenaNeural",
392
- "es-SV-RodrigoNeural",
393
- "es-GQ-JavierNeural",
394
- "es-GQ-TeresaNeural",
395
- "es-GT-AndresNeural",
396
- "es-GT-MartaNeural",
397
- "es-HN-CarlosNeural",
398
- "es-HN-KarlaNeural",
399
- "es-MX-DaliaNeural",
400
- "es-MX-JorgeNeural",
401
- "es-NI-FedericoNeural",
402
- "es-NI-YolandaNeural",
403
- "es-PA-MargaritaNeural",
404
- "es-PA-RobertoNeural",
405
- "es-PY-MarioNeural",
406
- "es-PY-TaniaNeural",
407
- "es-PE-AlexNeural",
408
- "es-PE-CamilaNeural",
409
- "es-PR-KarinaNeural",
410
- "es-PR-VictorNeural",
411
- "es-ES-AlvaroNeural",
412
- "es-ES-ElviraNeural",
413
- "es-US-AlonsoNeural",
414
- "es-US-PalomaNeural",
415
- "es-UY-MateoNeural",
416
- "es-UY-ValentinaNeural",
417
- "es-VE-PaolaNeural",
418
- "es-VE-SebastianNeural",
419
- "su-ID-JajangNeural",
420
- "su-ID-TutiNeural",
421
- "sw-KE-RafikiNeural",
422
- "sw-KE-ZuriNeural",
423
- "sw-TZ-DaudiNeural",
424
- "sw-TZ-RehemaNeural",
425
- "sv-SE-MattiasNeural",
426
- "sv-SE-SofieNeural",
427
- "ta-IN-PallaviNeural",
428
- "ta-IN-ValluvarNeural",
429
- "ta-MY-KaniNeural",
430
- "ta-MY-SuryaNeural",
431
- "ta-SG-AnbuNeural",
432
- "ta-SG-VenbaNeural",
433
- "ta-LK-KumarNeural",
434
- "ta-LK-SaranyaNeural",
435
- "te-IN-MohanNeural",
436
- "te-IN-ShrutiNeural",
437
- "th-TH-NiwatNeural",
438
- "th-TH-PremwadeeNeural",
439
- "tr-TR-AhmetNeural",
440
- "tr-TR-EmelNeural",
441
- "uk-UA-OstapNeural",
442
- "uk-UA-PolinaNeural",
443
- "ur-IN-GulNeural",
444
- "ur-IN-SalmanNeural",
445
- "ur-PK-AsadNeural",
446
- "ur-PK-UzmaNeural",
447
- "uz-UZ-MadinaNeural",
448
- "uz-UZ-SardorNeural",
449
- "vi-VN-HoaiMyNeural",
450
- "vi-VN-NamMinhNeural",
451
- "cy-GB-AledNeural",
452
- "cy-GB-NiaNeural",
453
- "zu-ZA-ThandoNeural",
454
- "zu-ZA-ThembaNeural"
455
- ],
456
- "google_tts_voice": [
457
- "af",
458
- "am",
459
- "ar",
460
- "bg",
461
- "bn",
462
- "bs",
463
- "ca",
464
- "cs",
465
- "cy",
466
- "da",
467
- "de",
468
- "el",
469
- "en",
470
- "es",
471
- "et",
472
- "eu",
473
- "fi",
474
- "fr",
475
- "fr-CA",
476
- "gl",
477
- "gu",
478
- "ha",
479
- "hi",
480
- "hr",
481
- "hu",
482
- "id",
483
- "is",
484
- "it",
485
- "iw",
486
- "ja",
487
- "jw",
488
- "km",
489
- "kn",
490
- "ko",
491
- "la",
492
- "lt",
493
- "lv",
494
- "ml",
495
- "mr",
496
- "ms",
497
- "my",
498
- "ne",
499
- "nl",
500
- "no",
501
- "pa",
502
- "pl",
503
- "pt",
504
- "pt-PT",
505
- "ro",
506
- "ru",
507
- "si",
508
- "sk",
509
- "sq",
510
- "sr",
511
- "su",
512
- "sv",
513
- "sw",
514
- "ta",
515
- "te",
516
- "th",
517
- "tl",
518
- "tr",
519
- "uk",
520
- "ur",
521
- "vi",
522
- "yue",
523
- "zh-CN",
524
- "zh-TW",
525
- "zh"
526
- ],
527
- "fp16": false,
528
- "separator_tab": true,
529
- "convert_tab": true,
530
- "convert_with_whisper": true,
531
- "tts_tab": true,
532
- "audioldm2": true,
533
- "effects_tab": true,
534
- "create_dataset_tab": true,
535
- "training_tab": true,
536
- "fushion_tab": true,
537
- "read_tab": true,
538
- "onnx_tab": true,
539
- "downloads_tab": true,
540
- "f0_extractor_tab": true,
541
- "settings_tab": true,
542
- "report_bug_tab": true,
543
- "font": "",
544
- "app_port": 7860,
545
- "tensorboard_port": 6870,
546
- "num_of_restart": 5,
547
- "server_name": "0.0.0.0",
548
- "app_show_error": true
549
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/config.py DELETED
@@ -1,90 +0,0 @@
1
- import os
2
- import json
3
- import torch
4
-
5
-
6
- version_config_paths = [os.path.join(version, size) for version in ["v1", "v2"] for size in ["32000.json", "40000.json", "48000.json"]]
7
-
8
- def singleton(cls):
9
- instances = {}
10
-
11
- def get_instance(*args, **kwargs):
12
- if cls not in instances: instances[cls] = cls(*args, **kwargs)
13
- return instances[cls]
14
-
15
- return get_instance
16
-
17
- @singleton
18
- class Config:
19
- def __init__(self):
20
- self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
21
- self.configs = json.load(open(os.path.join("main", "configs", "config.json"), "r"))
22
- self.translations = self.multi_language()
23
- self.json_config = self.load_config_json()
24
- self.gpu_mem = None
25
- self.per_preprocess = 3.7
26
- self.is_half = self.is_fp16()
27
- self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
28
-
29
- def multi_language(self):
30
- try:
31
- lang = self.configs.get("language", "en-US")
32
- if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("Không tìm thấy bất cứ gói ngôn ngữ nào(No package languages found)")
33
-
34
- if not lang: lang = "en-US"
35
- if lang not in self.configs["support_language"]: raise ValueError("Language not supported....")
36
-
37
- lang_path = os.path.join("assets", "languages", f"{lang}.json")
38
- if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", "en-US.json")
39
-
40
- with open(lang_path, encoding="utf-8") as f:
41
- translations = json.load(f)
42
- except json.JSONDecodeError:
43
- print(self.translations["empty_json"].format(file=lang))
44
- pass
45
-
46
- return translations
47
-
48
- def is_fp16(self):
49
- fp16 = self.configs.get("fp16", False)
50
-
51
- if self.device in ["cpu", "mps"] and fp16:
52
- self.configs["fp16"] = False
53
- fp16 = False
54
-
55
- with open(os.path.join("main", "configs", "config.json"), "w") as f:
56
- json.dump(self.configs, f, indent=4)
57
-
58
- if not fp16: self.preprocess_per = 3.0
59
- return fp16
60
-
61
- def load_config_json(self):
62
- configs = {}
63
-
64
- for config_file in version_config_paths:
65
- try:
66
- with open(os.path.join("main", "configs", config_file), "r") as f:
67
- configs[config_file] = json.load(f)
68
- except json.JSONDecodeError:
69
- print(self.translations["empty_json"].format(file=config_file))
70
- pass
71
-
72
- return configs
73
-
74
- def device_config(self):
75
- if self.device.startswith("cuda"): self.set_cuda_config()
76
- elif self.has_mps(): self.device = "mps"
77
- else: self.device = "cpu"
78
-
79
- if self.gpu_mem is not None and self.gpu_mem <= 4:
80
- self.preprocess_per = 3.0
81
- return 1, 5, 30, 32
82
-
83
- return (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)
84
-
85
- def set_cuda_config(self):
86
- i_device = int(self.device.split(":")[-1])
87
- self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
88
-
89
- def has_mps(self):
90
- return torch.backends.mps.is_available()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/decrypt.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:330268cbf6b9317a76510b533e1640ef48ed074a07c013e5b1abc4d48cfd9dce
3
- size 32
 
 
 
 
main/configs/v1/32000.json DELETED
@@ -1,46 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "seed": 1234,
5
- "epochs": 20000,
6
- "learning_rate": 0.0001,
7
- "betas": [0.8, 0.99],
8
- "eps": 1e-09,
9
- "batch_size": 4,
10
- "lr_decay": 0.999875,
11
- "segment_size": 12800,
12
- "init_lr_ratio": 1,
13
- "warmup_epochs": 0,
14
- "c_mel": 45,
15
- "c_kl": 1.0
16
- },
17
- "data": {
18
- "max_wav_value": 32768.0,
19
- "sample_rate": 32000,
20
- "filter_length": 1024,
21
- "hop_length": 320,
22
- "win_length": 1024,
23
- "n_mel_channels": 80,
24
- "mel_fmin": 0.0,
25
- "mel_fmax": null
26
- },
27
- "model": {
28
- "inter_channels": 192,
29
- "hidden_channels": 192,
30
- "filter_channels": 768,
31
- "text_enc_hidden_dim": 256,
32
- "n_heads": 2,
33
- "n_layers": 6,
34
- "kernel_size": 3,
35
- "p_dropout": 0,
36
- "resblock": "1",
37
- "resblock_kernel_sizes": [3, 7, 11],
38
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
39
- "upsample_rates": [10, 4, 2, 2, 2],
40
- "upsample_initial_channel": 512,
41
- "upsample_kernel_sizes": [16, 16, 4, 4, 4],
42
- "use_spectral_norm": false,
43
- "gin_channels": 256,
44
- "spk_embed_dim": 109
45
- }
46
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/v1/40000.json DELETED
@@ -1,46 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "seed": 1234,
5
- "epochs": 20000,
6
- "learning_rate": 0.0001,
7
- "betas": [0.8, 0.99],
8
- "eps": 1e-09,
9
- "batch_size": 4,
10
- "lr_decay": 0.999875,
11
- "segment_size": 12800,
12
- "init_lr_ratio": 1,
13
- "warmup_epochs": 0,
14
- "c_mel": 45,
15
- "c_kl": 1.0
16
- },
17
- "data": {
18
- "max_wav_value": 32768.0,
19
- "sample_rate": 40000,
20
- "filter_length": 2048,
21
- "hop_length": 400,
22
- "win_length": 2048,
23
- "n_mel_channels": 125,
24
- "mel_fmin": 0.0,
25
- "mel_fmax": null
26
- },
27
- "model": {
28
- "inter_channels": 192,
29
- "hidden_channels": 192,
30
- "filter_channels": 768,
31
- "text_enc_hidden_dim": 256,
32
- "n_heads": 2,
33
- "n_layers": 6,
34
- "kernel_size": 3,
35
- "p_dropout": 0,
36
- "resblock": "1",
37
- "resblock_kernel_sizes": [3, 7, 11],
38
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
39
- "upsample_rates": [10, 10, 2, 2],
40
- "upsample_initial_channel": 512,
41
- "upsample_kernel_sizes": [16, 16, 4, 4],
42
- "use_spectral_norm": false,
43
- "gin_channels": 256,
44
- "spk_embed_dim": 109
45
- }
46
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/v1/48000.json DELETED
@@ -1,46 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "seed": 1234,
5
- "epochs": 20000,
6
- "learning_rate": 0.0001,
7
- "betas": [0.8, 0.99],
8
- "eps": 1e-09,
9
- "batch_size": 4,
10
- "lr_decay": 0.999875,
11
- "segment_size": 11520,
12
- "init_lr_ratio": 1,
13
- "warmup_epochs": 0,
14
- "c_mel": 45,
15
- "c_kl": 1.0
16
- },
17
- "data": {
18
- "max_wav_value": 32768.0,
19
- "sample_rate": 48000,
20
- "filter_length": 2048,
21
- "hop_length": 480,
22
- "win_length": 2048,
23
- "n_mel_channels": 128,
24
- "mel_fmin": 0.0,
25
- "mel_fmax": null
26
- },
27
- "model": {
28
- "inter_channels": 192,
29
- "hidden_channels": 192,
30
- "filter_channels": 768,
31
- "text_enc_hidden_dim": 256,
32
- "n_heads": 2,
33
- "n_layers": 6,
34
- "kernel_size": 3,
35
- "p_dropout": 0,
36
- "resblock": "1",
37
- "resblock_kernel_sizes": [3, 7, 11],
38
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
39
- "upsample_rates": [10, 6, 2, 2, 2],
40
- "upsample_initial_channel": 512,
41
- "upsample_kernel_sizes": [16, 16, 4, 4, 4],
42
- "use_spectral_norm": false,
43
- "gin_channels": 256,
44
- "spk_embed_dim": 109
45
- }
46
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/v2/32000.json DELETED
@@ -1,42 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "seed": 1234,
5
- "learning_rate": 0.0001,
6
- "betas": [0.8, 0.99],
7
- "eps": 1e-09,
8
- "lr_decay": 0.999875,
9
- "segment_size": 12800,
10
- "c_mel": 45,
11
- "c_kl": 1.0
12
- },
13
- "data": {
14
- "max_wav_value": 32768.0,
15
- "sample_rate": 32000,
16
- "filter_length": 1024,
17
- "hop_length": 320,
18
- "win_length": 1024,
19
- "n_mel_channels": 80,
20
- "mel_fmin": 0.0,
21
- "mel_fmax": null
22
- },
23
- "model": {
24
- "inter_channels": 192,
25
- "hidden_channels": 192,
26
- "filter_channels": 768,
27
- "text_enc_hidden_dim": 768,
28
- "n_heads": 2,
29
- "n_layers": 6,
30
- "kernel_size": 3,
31
- "p_dropout": 0,
32
- "resblock": "1",
33
- "resblock_kernel_sizes": [3, 7, 11],
34
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
35
- "upsample_rates": [10, 8, 2, 2],
36
- "upsample_initial_channel": 512,
37
- "upsample_kernel_sizes": [20, 16, 4, 4],
38
- "use_spectral_norm": false,
39
- "gin_channels": 256,
40
- "spk_embed_dim": 109
41
- }
42
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/v2/40000.json DELETED
@@ -1,42 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "seed": 1234,
5
- "learning_rate": 0.0001,
6
- "betas": [0.8, 0.99],
7
- "eps": 1e-09,
8
- "lr_decay": 0.999875,
9
- "segment_size": 12800,
10
- "c_mel": 45,
11
- "c_kl": 1.0
12
- },
13
- "data": {
14
- "max_wav_value": 32768.0,
15
- "sample_rate": 40000,
16
- "filter_length": 2048,
17
- "hop_length": 400,
18
- "win_length": 2048,
19
- "n_mel_channels": 125,
20
- "mel_fmin": 0.0,
21
- "mel_fmax": null
22
- },
23
- "model": {
24
- "inter_channels": 192,
25
- "hidden_channels": 192,
26
- "filter_channels": 768,
27
- "text_enc_hidden_dim": 768,
28
- "n_heads": 2,
29
- "n_layers": 6,
30
- "kernel_size": 3,
31
- "p_dropout": 0,
32
- "resblock": "1",
33
- "resblock_kernel_sizes": [3, 7, 11],
34
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
35
- "upsample_rates": [10, 10, 2, 2],
36
- "upsample_initial_channel": 512,
37
- "upsample_kernel_sizes": [16, 16, 4, 4],
38
- "use_spectral_norm": false,
39
- "gin_channels": 256,
40
- "spk_embed_dim": 109
41
- }
42
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/configs/v2/48000.json DELETED
@@ -1,42 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "seed": 1234,
5
- "learning_rate": 0.0001,
6
- "betas": [0.8, 0.99],
7
- "eps": 1e-09,
8
- "lr_decay": 0.999875,
9
- "segment_size": 17280,
10
- "c_mel": 45,
11
- "c_kl": 1.0
12
- },
13
- "data": {
14
- "max_wav_value": 32768.0,
15
- "sample_rate": 48000,
16
- "filter_length": 2048,
17
- "hop_length": 480,
18
- "win_length": 2048,
19
- "n_mel_channels": 128,
20
- "mel_fmin": 0.0,
21
- "mel_fmax": null
22
- },
23
- "model": {
24
- "inter_channels": 192,
25
- "hidden_channels": 192,
26
- "filter_channels": 768,
27
- "text_enc_hidden_dim": 768,
28
- "n_heads": 2,
29
- "n_layers": 6,
30
- "kernel_size": 3,
31
- "p_dropout": 0,
32
- "resblock": "1",
33
- "resblock_kernel_sizes": [3, 7, 11],
34
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
35
- "upsample_rates": [12, 10, 2, 2],
36
- "upsample_initial_channel": 512,
37
- "upsample_kernel_sizes": [24, 20, 4, 4],
38
- "use_spectral_norm": false,
39
- "gin_channels": 256,
40
- "spk_embed_dim": 109
41
- }
42
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/audio_effects.py DELETED
@@ -1,180 +0,0 @@
1
- import os
2
- import sys
3
- import librosa
4
- import argparse
5
-
6
- import numpy as np
7
- import soundfile as sf
8
-
9
- from distutils.util import strtobool
10
- from scipy.signal import butter, filtfilt
11
- from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser, HighpassFilter
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from main.configs.config import Config
16
- from main.library.utils import pydub_convert, pydub_load
17
-
18
- translations = Config().translations
19
-
20
- def parse_arguments():
21
- parser = argparse.ArgumentParser()
22
- parser.add_argument("--input_path", type=str, required=True)
23
- parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
24
- parser.add_argument("--export_format", type=str, default="wav")
25
- parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
26
- parser.add_argument("--resample_sr", type=int, default=0)
27
- parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
28
- parser.add_argument("--chorus_depth", type=float, default=0.5)
29
- parser.add_argument("--chorus_rate", type=float, default=1.5)
30
- parser.add_argument("--chorus_mix", type=float, default=0.5)
31
- parser.add_argument("--chorus_delay", type=int, default=10)
32
- parser.add_argument("--chorus_feedback", type=float, default=0)
33
- parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
34
- parser.add_argument("--drive_db", type=int, default=20)
35
- parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
36
- parser.add_argument("--reverb_room_size", type=float, default=0.5)
37
- parser.add_argument("--reverb_damping", type=float, default=0.5)
38
- parser.add_argument("--reverb_wet_level", type=float, default=0.33)
39
- parser.add_argument("--reverb_dry_level", type=float, default=0.67)
40
- parser.add_argument("--reverb_width", type=float, default=1)
41
- parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
42
- parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
43
- parser.add_argument("--pitch_shift", type=int, default=0)
44
- parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
45
- parser.add_argument("--delay_seconds", type=float, default=0.5)
46
- parser.add_argument("--delay_feedback", type=float, default=0.5)
47
- parser.add_argument("--delay_mix", type=float, default=0.5)
48
- parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
49
- parser.add_argument("--compressor_threshold", type=int, default=-20)
50
- parser.add_argument("--compressor_ratio", type=float, default=4)
51
- parser.add_argument("--compressor_attack_ms", type=float, default=10)
52
- parser.add_argument("--compressor_release_ms", type=int, default=200)
53
- parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
54
- parser.add_argument("--limiter_threshold", type=int, default=0)
55
- parser.add_argument("--limiter_release", type=int, default=100)
56
- parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
57
- parser.add_argument("--gain_db", type=int, default=0)
58
- parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
59
- parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
60
- parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
61
- parser.add_argument("--clipping_threshold", type=int, default=-10)
62
- parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
63
- parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
64
- parser.add_argument("--phaser_depth", type=float, default=0.5)
65
- parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
66
- parser.add_argument("--phaser_feedback", type=float, default=0)
67
- parser.add_argument("--phaser_mix", type=float, default=0.5)
68
- parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
69
- parser.add_argument("--bass_boost_db", type=int, default=0)
70
- parser.add_argument("--bass_boost_frequency", type=int, default=100)
71
- parser.add_argument("--treble_boost_db", type=int, default=0)
72
- parser.add_argument("--treble_boost_frequency", type=int, default=3000)
73
- parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
74
- parser.add_argument("--fade_in_duration", type=float, default=2000)
75
- parser.add_argument("--fade_out_duration", type=float, default=2000)
76
- parser.add_argument("--audio_combination", type=lambda x: bool(strtobool(x)), default=False)
77
- parser.add_argument("--audio_combination_input", type=str)
78
-
79
- return parser.parse_args()
80
-
81
- def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out, audio_combination, audio_combination_input):
82
- def bass_boost(audio, gain_db, frequency, sample_rate):
83
- if gain_db >= 1:
84
- b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')
85
-
86
- return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
87
- else: return audio
88
-
89
- def treble_boost(audio, gain_db, frequency, sample_rate):
90
- if gain_db >=1:
91
- b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')
92
-
93
- return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
94
- else: return audio
95
-
96
- def fade_out_effect(audio, sr, duration=3.0):
97
- length = int(duration * sr)
98
- end = audio.shape[0]
99
-
100
- if length > end: length = end
101
- start = end - length
102
-
103
- audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
104
- return audio
105
-
106
- def fade_in_effect(audio, sr, duration=3.0):
107
- length = int(duration * sr)
108
- start = 0
109
-
110
- if length > audio.shape[0]: length = audio.shape[0]
111
- end = length
112
-
113
- audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
114
- return audio
115
-
116
- if not input_path or not os.path.exists(input_path):
117
- print(translations["input_not_valid"])
118
- sys.exit(1)
119
-
120
- if not output_path:
121
- print(translations["output_not_valid"])
122
- sys.exit(1)
123
-
124
- if os.path.exists(output_path): os.remove(output_path)
125
-
126
- try:
127
- input_path = input_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
128
-
129
- try:
130
- audio, sample_rate = sf.read(input_path, dtype=np.float32)
131
- except:
132
- audio, sample_rate = librosa.load(input_path, sr=None)
133
- except Exception as e:
134
- raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
135
-
136
- audio = audio.flatten()
137
-
138
- try:
139
- board = Pedalboard([HighpassFilter()])
140
-
141
- if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
142
- if distortion: board.append(Distortion(drive_db=distortion_drive))
143
- if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
144
- if pitchshift: board.append(PitchShift(semitones=pitch_shift))
145
- if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
146
- if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
147
- if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
148
- if gain: board.append(Gain(gain_db=gain_db))
149
- if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
150
- if clipping: board.append(Clipping(threshold_db=clipping_threshold))
151
- if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))
152
-
153
- processed_audio = board(audio, sample_rate)
154
-
155
- if treble_bass_boost:
156
- processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
157
- processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)
158
-
159
- if fade_in_out:
160
- processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
161
- processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)
162
-
163
- if resample_sr != sample_rate and resample_sr > 0 and resample:
164
- target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - resample_sr))
165
- processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=target_sr, res_type="soxr_vhq")
166
- sample_rate = target_sr
167
-
168
- sf.write(output_path.replace("wav", export_format), processed_audio, sample_rate, format=export_format)
169
-
170
- if audio_combination: pydub_convert(pydub_load(audio_combination_input)).overlay(pydub_convert(pydub_load(output_path.replace("wav", export_format)))).export(output_path.replace("wav", export_format), format=export_format)
171
- except Exception as e:
172
- raise RuntimeError(translations["apply_error"].format(e=e))
173
-
174
- return output_path
175
-
176
- def main():
177
- args = parse_arguments()
178
- process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out, audio_combination=args.audio_combination, audio_combination_input=args.audio_combination_input)
179
-
180
- if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/audioldm2.py DELETED
@@ -1,210 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import tqdm
5
- import torch
6
- import logging
7
- import librosa
8
- import argparse
9
- import scipy.signal
10
- import logging.handlers
11
-
12
- import numpy as np
13
- import soundfile as sf
14
-
15
- from torch import inference_mode
16
- from distutils.util import strtobool
17
-
18
- sys.path.append(os.getcwd())
19
-
20
- from main.configs.config import Config
21
- from main.library.audioldm2.utils import load_audio
22
- from main.library.audioldm2.models import load_model
23
-
24
- config = Config()
25
- translations = config.translations
26
- logger = logging.getLogger(__name__)
27
- logger.propagate = False
28
-
29
- for l in ["torch", "httpx", "httpcore", "diffusers", "transformers"]:
30
- logging.getLogger(l).setLevel(logging.ERROR)
31
-
32
- if logger.hasHandlers(): logger.handlers.clear()
33
- else:
34
- console_handler = logging.StreamHandler()
35
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
36
- console_handler.setFormatter(console_formatter)
37
- console_handler.setLevel(logging.INFO)
38
- file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "audioldm2.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
39
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
40
- file_handler.setFormatter(file_formatter)
41
- file_handler.setLevel(logging.DEBUG)
42
- logger.addHandler(console_handler)
43
- logger.addHandler(file_handler)
44
- logger.setLevel(logging.DEBUG)
45
-
46
- def parse_arguments():
47
- parser = argparse.ArgumentParser()
48
- parser.add_argument("--input_path", type=str, required=True)
49
- parser.add_argument("--output_path", type=str, default="./output.wav")
50
- parser.add_argument("--export_format", type=str, default="wav")
51
- parser.add_argument("--sample_rate", type=int, default=44100)
52
- parser.add_argument("--audioldm_model", type=str, default="audioldm2-music")
53
- parser.add_argument("--source_prompt", type=str, default="")
54
- parser.add_argument("--target_prompt", type=str, default="")
55
- parser.add_argument("--steps", type=int, default=200)
56
- parser.add_argument("--cfg_scale_src", type=float, default=3.5)
57
- parser.add_argument("--cfg_scale_tar", type=float, default=12)
58
- parser.add_argument("--t_start", type=int, default=45)
59
- parser.add_argument("--save_compute", type=lambda x: bool(strtobool(x)), default=False)
60
-
61
- return parser.parse_args()
62
-
63
- def main():
64
- args = parse_arguments()
65
- input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute = args.input_path, args.output_path, args.export_format, args.sample_rate, args.audioldm_model, args.source_prompt, args.target_prompt, args.steps, args.cfg_scale_src, args.cfg_scale_tar, args.t_start, args.save_compute
66
-
67
- log_data = {translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_name']: audioldm_model, translations['export_format']: export_format, translations['sample_rate']: sample_rate, translations['steps']: steps, translations['source_prompt']: source_prompt, translations['target_prompt']: target_prompt, translations['cfg_scale_src']: cfg_scale_src, translations['cfg_scale_tar']: cfg_scale_tar, translations['t_start']: t_start, translations['save_compute']: save_compute}
68
-
69
- for key, value in log_data.items():
70
- logger.debug(f"{key}: {value}")
71
-
72
- start_time = time.time()
73
- logger.info(translations["start_edit"].format(input_path=input_path))
74
- pid_path = os.path.join("assets", "audioldm2_pid.txt")
75
- with open(pid_path, "w") as pid_file:
76
- pid_file.write(str(os.getpid()))
77
-
78
- try:
79
- edit(input_path, output_path, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute, sample_rate, config.device, export_format=export_format)
80
- except Exception as e:
81
- logger.error(translations["error_edit"].format(e=e))
82
- import traceback
83
- logger.debug(traceback.format_exc())
84
-
85
- logger.info(translations["edit_success"].format(time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
86
-
87
- def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
88
- with inference_mode():
89
- w0 = ldm_stable.vae_encode(x0)
90
-
91
- _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute)
92
- return zs, wts, extra_info
93
-
94
- def low_pass_filter(audio, cutoff=7500, sr=16000):
95
- b, a = scipy.signal.butter(4, cutoff / (sr / 2), btype='low')
96
- return scipy.signal.filtfilt(b, a, audio)
97
-
98
- def sample(output_audio, sr, ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute, export_format = "wav"):
99
- tstart = torch.tensor(tstart, dtype=torch.int32)
100
- w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute)
101
-
102
- with inference_mode():
103
- x0_dec = ldm_stable.vae_decode(w0.to(torch.float16 if config.is_half else torch.float32))
104
-
105
- if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :]
106
-
107
- with torch.no_grad():
108
- audio = ldm_stable.decode_to_mel(x0_dec.to(torch.float16 if config.is_half else torch.float32))
109
-
110
- audio = audio.float().squeeze().cpu().numpy()
111
- orig_sr = 16000
112
-
113
- if sr != 16000 and sr > 0:
114
- audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr, res_type="soxr_vhq")
115
- orig_sr = sr
116
-
117
- audio = low_pass_filter(audio, 7500, orig_sr)
118
-
119
- sf.write(output_audio, np.tile(audio, (2, 1)).T, orig_sr, format=export_format)
120
- return output_audio
121
-
122
- def edit(input_audio, output_audio, model_id, source_prompt = "", target_prompt = "", steps = 200, cfg_scale_src = 3.5, cfg_scale_tar = 12, t_start = 45, save_compute = True, sr = 44100, device = "cpu", export_format = "wav"):
123
- ldm_stable = load_model(model_id, device=device)
124
- ldm_stable.model.scheduler.set_timesteps(steps, device=device)
125
- x0, duration = load_audio(input_audio, ldm_stable.get_melspectrogram(), device=device)
126
- zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute)
127
-
128
- return sample(output_audio, sr, ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute, export_format=export_format)
129
-
130
- def inversion_forward_process(model, x0, etas = None, prompts = [""], cfg_scales = [3.5], num_inference_steps = 50, numerical_fix = False, duration = None, first_order = False, save_compute = True):
131
- if len(prompts) > 1 or prompts[0] != "":
132
- text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
133
- uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
134
- else: uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)
135
-
136
- timesteps = model.model.scheduler.timesteps.to(model.device)
137
- variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
138
-
139
- if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
140
-
141
- xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
142
- zs = torch.zeros(size=variance_noise_shape, device=model.device)
143
- extra_info = [None] * len(zs)
144
-
145
- if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
146
- elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
147
-
148
- xt = x0
149
- model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration, save_compute=save_compute and prompts[0] != "")
150
-
151
- for t in tqdm.tqdm(timesteps, desc=translations["inverting"], ncols=100, unit="a"):
152
- idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
153
- xt = xts[idx + 1][None]
154
- xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)
155
-
156
- with torch.no_grad():
157
- if save_compute and prompts[0] != "":
158
- comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
159
- out, cond_out = comb_out.sample.chunk(2, dim=0)
160
- else:
161
- out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
162
- if len(prompts) > 1 or prompts[0] != "": cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
163
-
164
- if len(prompts) > 1 or prompts[0] != "": noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
165
- else: noise_pred = out
166
-
167
- xtm1 = xts[idx][None]
168
- z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t, eta=etas[idx], numerical_fix=numerical_fix, first_order=first_order)
169
- zs[idx] = z
170
- xts[idx] = xtm1
171
- extra_info[idx] = extra
172
-
173
- if zs is not None: zs[0] = torch.zeros_like(zs[0])
174
- return xt, zs, xts, extra_info
175
-
176
- def inversion_reverse_process(model, xT, tstart, etas = 0, prompts = [""], neg_prompts = [""], cfg_scales = None, zs = None, duration = None, first_order = False, extra_info = None, save_compute = True):
177
- text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
178
- uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(neg_prompts, negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
179
- xt = xT[tstart.max()].unsqueeze(0)
180
-
181
- if etas is None: etas = 0
182
- if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
183
-
184
- assert len(etas) == model.model.scheduler.num_inference_steps
185
- timesteps = model.model.scheduler.timesteps.to(model.device)
186
-
187
- if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
188
- elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
189
-
190
- model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]], audio_end_in_s=duration, save_compute=save_compute)
191
-
192
- for t in tqdm.tqdm(timesteps[-zs.shape[0]:], desc=translations["editing"], ncols=100, unit="a"):
193
- idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
194
- xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)
195
-
196
- with torch.no_grad():
197
- if save_compute:
198
- comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
199
- uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
200
- else:
201
- uncond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
202
- cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
203
-
204
- z = zs[idx] if zs is not None else None
205
- noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
206
- xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z.unsqueeze(0), eta=etas[idx], first_order=first_order)
207
-
208
- return xt, zs
209
-
210
- if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/convert.py DELETED
@@ -1,590 +0,0 @@
1
- import re
2
- import os
3
- import gc
4
- import sys
5
- import time
6
- import faiss
7
- import torch
8
- import librosa
9
- import logging
10
- import argparse
11
- import warnings
12
- import onnxruntime
13
- import logging.handlers
14
-
15
- import numpy as np
16
- import soundfile as sf
17
- import torch.nn.functional as F
18
-
19
- from tqdm import tqdm
20
- from scipy import signal
21
- from distutils.util import strtobool
22
-
23
- warnings.filterwarnings("ignore")
24
- sys.path.append(os.getcwd())
25
-
26
- from main.configs.config import Config
27
- from main.library.algorithm.synthesizers import Synthesizer
28
- from main.library.utils import check_predictors, check_embedders, load_audio, load_embedders_model, cut, restore
29
-
30
# High-pass Butterworth filter (order 5, 48 Hz cutoff at 16 kHz input rate):
# removes DC offset / low-frequency rumble before pitch extraction.
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
config = Config()
translations = config.translations
logger = logging.getLogger(__name__)
logger.propagate = False

# Quiet down chatty third-party loggers.
for l in ["torch", "faiss", "httpx", "httpcore", "faiss.loader", "numba.core", "urllib3", "transformers", "matplotlib"]:
    logging.getLogger(l).setLevel(logging.ERROR)

# Reset any handlers left from a previous import of this module, then always
# (re)attach a console handler (INFO) and a rotating file handler (DEBUG).
# The previous code only attached handlers in the "no handlers yet" branch,
# so a re-import cleared the handlers and left the logger silent.
if logger.hasHandlers(): logger.handlers.clear()

console_handler = logging.StreamHandler()
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
console_handler.setFormatter(console_formatter)
console_handler.setLevel(logging.INFO)

# RotatingFileHandler raises if the directory is missing; create it up front.
os.makedirs(os.path.join("assets", "logs"), exist_ok=True)
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "convert.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
file_handler.setFormatter(file_formatter)
file_handler.setLevel(logging.DEBUG)

logger.addHandler(console_handler)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)
52
-
53
def parse_arguments():
    """Build and parse the CLI arguments for the voice-conversion script.

    Returns:
        argparse.Namespace with all conversion options (pitch/f0 settings,
        input/output paths, model paths, cleanup and formant options).
    """
    def str2bool(value):
        # Local replacement for distutils.util.strtobool: distutils is
        # deprecated (PEP 632) and removed in Python 3.12.  Accepts the same
        # truthy/falsy spellings as the original.
        value = str(value).strip().lower()
        if value in ("y", "yes", "t", "true", "on", "1"): return True
        if value in ("n", "no", "f", "false", "off", "0"): return False
        raise argparse.ArgumentTypeError(f"invalid truth value {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--pitch", type=int, default=0)
    parser.add_argument("--filter_radius", type=int, default=3)
    parser.add_argument("--index_rate", type=float, default=0.5)
    parser.add_argument("--volume_envelope", type=float, default=1)
    parser.add_argument("--protect", type=float, default=0.33)
    parser.add_argument("--hop_length", type=int, default=64)
    parser.add_argument("--f0_method", type=str, default="rmvpe")
    parser.add_argument("--embedder_model", type=str, default="contentvec_base")
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./audios/output.wav")
    parser.add_argument("--export_format", type=str, default="wav")
    parser.add_argument("--pth_path", type=str, required=True)
    parser.add_argument("--index_path", type=str)
    parser.add_argument("--f0_autotune", type=str2bool, default=False)
    parser.add_argument("--f0_autotune_strength", type=float, default=1)
    parser.add_argument("--clean_audio", type=str2bool, default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--resample_sr", type=int, default=0)
    parser.add_argument("--split_audio", type=str2bool, default=False)
    parser.add_argument("--checkpointing", type=str2bool, default=False)
    parser.add_argument("--f0_file", type=str, default="")
    parser.add_argument("--f0_onnx", type=str2bool, default=False)
    parser.add_argument("--embedders_mode", type=str, default="fairseq")
    parser.add_argument("--formant_shifting", type=str2bool, default=False)
    parser.add_argument("--formant_qfrency", type=float, default=0.8)
    parser.add_argument("--formant_timbre", type=float, default=0.8)

    return parser.parse_args()
83
-
84
def main():
    """Entry point: parse the CLI arguments, log the chosen settings, then run the conversion."""
    args = parse_arguments()

    # Settings summary, keyed by localized labels.
    log_data = {
        translations['pitch']: args.pitch,
        translations['filter_radius']: args.filter_radius,
        translations['index_strength']: args.index_rate,
        translations['volume_envelope']: args.volume_envelope,
        translations['protect']: args.protect,
        "Hop length": args.hop_length,
        translations['f0_method']: args.f0_method,
        translations['audio_path']: args.input_path,
        translations['output_path']: args.output_path.replace('wav', args.export_format),
        translations['model_path']: args.pth_path,
        translations['indexpath']: args.index_path,
        translations['autotune']: args.f0_autotune,
        translations['clear_audio']: args.clean_audio,
        translations['export_format']: args.export_format,
        translations['hubert_model']: args.embedder_model,
        translations['split_audio']: args.split_audio,
        translations['memory_efficient_training']: args.checkpointing,
        translations["f0_onnx_mode"]: args.f0_onnx,
        translations["embed_mode"]: args.embedders_mode,
    }

    # Optional entries, only logged when the corresponding feature is active.
    if args.clean_audio: log_data[translations['clean_strength']] = args.clean_strength
    if args.resample_sr != 0: log_data[translations['sample_rate']] = args.resample_sr

    if args.f0_autotune: log_data[translations['autotune_rate_info']] = args.f0_autotune_strength
    if os.path.isfile(args.f0_file): log_data[translations['f0_file']] = args.f0_file

    if args.formant_shifting:
        log_data[translations['formant_qfrency']] = args.formant_qfrency
        log_data[translations['formant_timbre']] = args.formant_timbre

    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    run_convert_script(pitch=args.pitch, filter_radius=args.filter_radius, index_rate=args.index_rate, volume_envelope=args.volume_envelope, protect=args.protect, hop_length=args.hop_length, f0_method=args.f0_method, input_path=args.input_path, output_path=args.output_path, pth_path=args.pth_path, index_path=args.index_path, f0_autotune=args.f0_autotune, f0_autotune_strength=args.f0_autotune_strength, clean_audio=args.clean_audio, clean_strength=args.clean_strength, export_format=args.export_format, embedder_model=args.embedder_model, resample_sr=args.resample_sr, split_audio=args.split_audio, checkpointing=args.checkpointing, f0_file=args.f0_file, f0_onnx=args.f0_onnx, embedders_mode=args.embedders_mode, formant_shifting=args.formant_shifting, formant_qfrency=args.formant_qfrency, formant_timbre=args.formant_timbre)
104
-
105
def run_convert_script(pitch=0, filter_radius=3, index_rate=0.5, volume_envelope=1, protect=0.5, hop_length=64, f0_method="rmvpe", input_path=None, output_path="./output.wav", pth_path=None, index_path=None, f0_autotune=False, f0_autotune_strength=1, clean_audio=False, clean_strength=0.7, export_format="wav", embedder_model="contentvec_base", resample_sr=0, split_audio=False, checkpointing=False, f0_file=None, f0_onnx=False, embedders_mode="fairseq", formant_shifting=False, formant_qfrency=0.8, formant_timbre=0.8):
    """Run a voice conversion for a single audio file or a whole directory.

    If ``input_path`` is a directory, every supported audio file inside it is
    converted and written next to it as ``<name>_output.<export_format>``;
    otherwise the single file is converted to ``output_path``.  Exits the
    process (``sys.exit(1)``) on a missing/invalid model or missing audio.
    """
    # Ensure the pitch-predictor and embedder weights are downloaded/present.
    check_predictors(f0_method, f0_onnx); check_embedders(embedder_model, embedders_mode)

    # The model checkpoint must be an existing .pth or .onnx file (not a dir).
    if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith((".pth", ".onnx")):
        logger.warning(translations["provide_file"].format(filename=translations["model"]))
        sys.exit(1)

    cvt = VoiceConverter(pth_path, 0)
    start_time = time.time()

    # Record this process id so an external UI/watchdog can kill the run.
    pid_path = os.path.join("assets", "convert_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    if os.path.isdir(input_path):
        # Batch mode: convert all supported audio files in the directory.
        logger.info(translations["convert_batch"])
        audio_files = [f for f in os.listdir(input_path) if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]

        if not audio_files:
            logger.warning(translations["not_found_audio"])
            sys.exit(1)

        logger.info(translations["found_audio"].format(audio_files=len(audio_files)))

        for audio in audio_files:
            audio_path = os.path.join(input_path, audio)
            # Output is written next to the input, suffixed with "_output".
            output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")

            logger.info(f"{translations['convert_audio']} '{audio_path}'...")
            # Overwrite any previous result for this file.
            if os.path.exists(output_audio): os.remove(output_audio)
            cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre, split_audio=split_audio)

        logger.info(translations["convert_batch_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
    else:
        # Single-file mode.
        if not os.path.exists(input_path):
            logger.warning(translations["not_found_audio"])
            sys.exit(1)

        logger.info(f"{translations['convert_audio']} '{input_path}'...")
        if os.path.exists(output_path): os.remove(output_path)
        cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre, split_audio=split_audio)

    # Conversion finished: remove the pid marker and log the elapsed time.
    if os.path.exists(pid_path): os.remove(pid_path)
    logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
149
-
150
def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
    """Blend target_audio's loudness envelope toward source_audio's.

    ``rate`` = 1 leaves the target untouched; ``rate`` = 0 fully imposes the
    source RMS contour.  Returns the rescaled target as a numpy array.
    """
    def _envelope(wave, sr):
        # Per-frame RMS, resampled (linear) to one value per target sample.
        rms = librosa.feature.rms(y=wave, frame_length=sr // 2 * 2, hop_length=sr // 2)
        curve = torch.from_numpy(rms).float().unsqueeze(0)
        return F.interpolate(curve, size=target_audio.shape[0], mode="linear").squeeze()

    source_env = _envelope(source_audio, source_rate)
    target_env = _envelope(target_audio, target_rate)
    # Clamp the target envelope away from zero to avoid division blow-ups.
    target_env = torch.maximum(target_env, torch.zeros_like(target_env) + 1e-6)
    gain = torch.pow(source_env, 1 - rate) * torch.pow(target_env, rate - 1)
    return target_audio * gain.numpy()
153
-
154
def clear_gpu_cache():
    """Run the garbage collector and release cached accelerator memory (CUDA or MPS)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
158
-
159
def get_providers():
    """Pick the best available onnxruntime execution provider.

    Preference order: CUDA, then CoreML, falling back to CPU.  Returns a
    single-element list suitable for onnxruntime session creation.
    """
    available = onnxruntime.get_available_providers()

    for preferred in ("CUDAExecutionProvider", "CoreMLExecutionProvider"):
        if preferred in available:
            return [preferred]

    return ["CPUExecutionProvider"]
167
-
168
class Autotune:
    """Snap f0 (pitch) values toward the nearest reference note frequency."""

    def __init__(self, ref_freqs):
        # note_dict intentionally aliases ref_freqs: callers read either name.
        self.ref_freqs = ref_freqs
        self.note_dict = self.ref_freqs

    def autotune_f0(self, f0, f0_autotune_strength):
        """Return a new array with each f0 value pulled toward its closest
        reference note; strength 1 snaps fully, 0 leaves f0 unchanged."""
        snapped = np.zeros_like(f0)

        for idx, freq in enumerate(f0):
            nearest = min(self.note_dict, key=lambda note: abs(note - freq))
            snapped[idx] = freq + (nearest - freq) * f0_autotune_strength

        return snapped
180
-
181
class VC:
    """Core RVC voice-conversion pipeline.

    Extracts f0 (pitch) with one of many predictors, extracts content
    features with a HuBERT-style embedder, optionally blends them with a
    FAISS feature index, and runs the synthesizer to produce converted audio.
    Internal processing happens at 16 kHz (``self.sample_rate``) with a
    160-sample hop (``self.window``).
    """

    def __init__(self, tgt_sr, config):
        # Padding/segmentation parameters (in seconds) come from the global config.
        self.x_pad = config.x_pad
        self.x_query = config.x_query
        self.x_center = config.x_center
        self.x_max = config.x_max
        self.sample_rate = 16000          # internal processing rate (Hz)
        self.window = 160                 # hop size in samples (10 ms at 16 kHz)
        self.t_pad = self.sample_rate * self.x_pad      # reflect-pad length (samples)
        self.t_pad_tgt = tgt_sr * self.x_pad            # same pad at target rate, trimmed from output
        self.t_pad2 = self.t_pad * 2
        self.t_query = self.sample_rate * self.x_query  # search radius for quiet split points
        self.t_center = self.sample_rate * self.x_center  # nominal chunk spacing
        self.t_max = self.sample_rate * self.x_max      # max length before chunking kicks in
        self.time_step = self.window / self.sample_rate * 1000  # hop in ms
        self.f0_min = 50
        self.f0_max = 1100
        # Mel-scale bounds used to quantize f0 into 1..255 coarse bins.
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.device = config.device
        self.is_half = config.is_half
        # Equal-temperament note frequencies (G1..C6) used for autotune snapping.
        self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
        self.autotune = Autotune(self.ref_freqs)
        self.note_dict = self.autotune.note_dict

    def get_f0_pm(self, x, p_len):
        """f0 via Praat/parselmouth autocorrelation, zero-padded to p_len frames."""
        import parselmouth

        f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
        pad_size = (p_len - len(f0) + 1) // 2

        # Praat may return fewer frames than expected; pad symmetrically.
        if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
        return f0

    def get_f0_mangio_crepe(self, x, p_len, hop_length, model="full", onnx=False):
        """f0 via CREPE ("mangio" variant): peak-normalized input, interpolated to p_len."""
        from main.library.predictors.CREPE import predict

        x = x.astype(np.float32)
        x /= np.quantile(np.abs(x), 0.999)  # robust peak normalization

        audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
        if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()

        p_len = p_len or x.shape[0] // hop_length
        source = np.array(predict(audio.detach(), self.sample_rate, hop_length, self.f0_min, self.f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True, providers=get_providers(), onnx=onnx).squeeze(0).cpu().float().numpy())
        # Mark near-zero estimates as unvoiced before resampling.
        source[source < 0.001] = np.nan

        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))

    def get_f0_crepe(self, x, model="full", onnx=False):
        """f0 via standard CREPE with mean/median smoothing and periodicity gating."""
        from main.library.predictors.CREPE import predict, mean, median

        f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.sample_rate, self.window, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=get_providers(), onnx=onnx)
        f0, pd = mean(f0, 3), median(pd, 3)
        f0[pd < 0.1] = 0  # zero out low-confidence (unvoiced) frames

        return f0[0].cpu().numpy()

    def get_f0_fcpe(self, x, p_len, hop_length, onnx=False, legacy=False):
        """f0 via FCPE (fast context-based pitch estimator); legacy flag selects the older checkpoint."""
        from main.library.predictors.FCPE import FCPE

        model_fcpe = FCPE(os.path.join("assets", "models", "predictors", ("fcpe_legacy" if legacy else "fcpe") + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03 if legacy else 0.006, providers=get_providers(), onnx=onnx, legacy=legacy)
        f0 = model_fcpe.compute_f0(x, p_len=p_len)

        del model_fcpe
        return f0

    def get_f0_rmvpe(self, x, legacy=False, onnx=False):
        """f0 via RMVPE; legacy mode passes explicit f0 bounds to the model."""
        from main.library.predictors.RMVPE import RMVPE

        rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), is_half=self.is_half, device=self.device, onnx=onnx, providers=get_providers())
        f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)

        del rmvpe_model
        return f0

    def get_f0_pyworld(self, x, filter_radius, model="harvest"):
        """f0 via pyworld (harvest or dio) refined with stonemask; optional median filter."""
        from main.library.predictors.WORLD_WRAPPER import PYWORLD

        pw = PYWORLD()
        x = x.astype(np.double)

        if model == "harvest": f0, t = pw.harvest(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
        elif model == "dio": f0, t = pw.dio(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
        else: raise ValueError(translations["method_not_valid"])

        f0 = pw.stonemask(x, self.sample_rate, t, f0)

        # dio is noisy enough to always smooth; harvest only above radius 2.
        if filter_radius > 2 or model == "dio": f0 = signal.medfilt(f0, filter_radius)
        return f0

    def get_f0_swipe(self, x):
        """f0 via the SWIPE pitch tracker."""
        from main.library.predictors.SWIPE import swipe

        f0, _ = swipe(x.astype(np.float32), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=10)
        return f0

    def get_f0_yin(self, x, hop_length, p_len, mode="yin"):
        """f0 via librosa yin/pyin, interpolated to p_len frames."""
        source = np.array(librosa.yin(x.astype(np.float32), sr=self.sample_rate, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length) if mode == "yin" else librosa.pyin(x.astype(np.float32), fmin=self.f0_min, fmax=self.f0_max, sr=self.sample_rate, hop_length=hop_length)[0])
        source[source < 0.001] = np.nan
        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))

    def get_f0_hybrid(self, methods_str, x, p_len, hop_length, filter_radius, onnx_mode):
        """Combine several f0 methods ("hybrid[a+b+...]") by per-frame nanmedian."""
        methods_str = re.search("hybrid\[(.+)\]", methods_str)
        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]

        f0_computation_stack, resampled_stack = [], []
        logger.debug(translations["hybrid_methods"].format(methods=methods))

        x = x.astype(np.float32)
        x /= np.quantile(np.abs(x), 0.999)

        # Dispatch table: method name -> f0 extractor call.
        for method in methods:
            f0 = None
            f0_methods = {"pm": lambda: self.get_f0_pm(x, p_len), "dio": lambda: self.get_f0_pyworld(x, filter_radius, "dio"), "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=onnx_mode), "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=onnx_mode), "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=onnx_mode), "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=onnx_mode), "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=onnx_mode), "crepe-tiny": lambda: self.get_f0_crepe(x, "tiny", onnx=onnx_mode), "crepe-small": lambda: self.get_f0_crepe(x, "small", onnx=onnx_mode), "crepe-medium": lambda: self.get_f0_crepe(x, "medium", onnx=onnx_mode), "crepe-large": lambda: self.get_f0_crepe(x, "large", onnx=onnx_mode), "crepe-full": lambda: self.get_f0_crepe(x, "full", onnx=onnx_mode), "fcpe": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), onnx=onnx_mode), "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True, onnx=onnx_mode), "rmvpe": lambda: self.get_f0_rmvpe(x, onnx=onnx_mode), "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, legacy=True, onnx=onnx_mode), "harvest": lambda: self.get_f0_pyworld(x, filter_radius, "harvest"), "yin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="yin"), "pyin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="pyin"), "swipe": lambda: self.get_f0_swipe(x)}
            f0 = f0_methods.get(method, lambda: ValueError(translations["method_not_valid"]))()
            f0_computation_stack.append(f0)

        # Resample every curve to p_len frames before merging.
        for f0 in f0_computation_stack:
            resampled_stack.append(np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0))

        return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)

    def get_f0(self, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength, inp_f0=None, onnx_mode=False):
        """Compute f0, apply transposition/autotune/optional f0-file override,
        and return (coarse 1..255 mel-quantized f0, continuous f0)."""
        f0_methods = {"pm": lambda: self.get_f0_pm(x, p_len), "dio": lambda: self.get_f0_pyworld(x, filter_radius, "dio"), "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=onnx_mode), "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=onnx_mode), "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=onnx_mode), "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=onnx_mode), "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=onnx_mode), "crepe-tiny": lambda: self.get_f0_crepe(x, "tiny", onnx=onnx_mode), "crepe-small": lambda: self.get_f0_crepe(x, "small", onnx=onnx_mode), "crepe-medium": lambda: self.get_f0_crepe(x, "medium", onnx=onnx_mode), "crepe-large": lambda: self.get_f0_crepe(x, "large", onnx=onnx_mode), "crepe-full": lambda: self.get_f0_crepe(x, "full", onnx=onnx_mode), "fcpe": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), onnx=onnx_mode), "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True, onnx=onnx_mode), "rmvpe": lambda: self.get_f0_rmvpe(x, onnx=onnx_mode), "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, legacy=True, onnx=onnx_mode), "harvest": lambda: self.get_f0_pyworld(x, filter_radius, "harvest"), "yin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="yin"), "pyin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="pyin"), "swipe": lambda: self.get_f0_swipe(x)}
        f0 = self.get_f0_hybrid(f0_method, x, p_len, hop_length, filter_radius, onnx_mode) if "hybrid" in f0_method else f0_methods.get(f0_method, lambda: ValueError(translations["method_not_valid"]))()

        if f0_autotune: f0 = Autotune.autotune_f0(self, f0, f0_autotune_strength)
        # Some predictors return (f0, extra); keep only the pitch curve.
        if isinstance(f0, tuple): f0 = f0[0]

        # Transpose by `pitch` semitones.
        f0 *= pow(2, pitch / 12)
        tf0 = self.sample_rate // self.window  # f0 frames per second

        # Optional user-supplied f0 curve (time, hz) overrides the tracked one.
        if inp_f0 is not None:
            replace_f0 = np.interp(list(range(np.round((inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1).astype(np.int16))), inp_f0[:, 0] * 100, inp_f0[:, 1])
            f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[:f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]]

        # Quantize to 1..255 mel bins for the coarse pitch embedding.
        f0_mel = 1127 * np.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255

        return np.rint(f0_mel).astype(np.int32), f0.copy()

    def extract_features(self, model, feats, version):
        """Run an ONNX embedder; output index depends on model version (v1 vs v2)."""
        return torch.as_tensor(model.run([model.get_outputs()[0].name, model.get_outputs()[1].name], {"feats": feats.detach().cpu().numpy()})[0 if version == "v1" else 1], dtype=torch.float32, device=feats.device)

    def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
        """Convert one audio segment: embed, optionally index-blend, synthesize.

        Returns the converted waveform as a float numpy array.  `protect` < 0.5
        blends pre-index features back on unvoiced frames to preserve consonants.
        """
        pitch_guidance = pitch != None and pitchf != None
        feats = (torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float())

        if feats.dim() == 2: feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        feats = feats.view(1, -1)

        with torch.no_grad():
            # Embedder dispatch by checkpoint format (.pt fairseq / .onnx / .safetensors).
            if self.embed_suffix == ".pt":
                padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
                logits = model.extract_features(**{"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12})
                feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
            elif self.embed_suffix == ".onnx": feats = self.extract_features(model, feats.to(self.device), version).to(self.device)
            elif self.embed_suffix == ".safetensors":
                logits = model(feats.to(self.device))["last_hidden_state"]
                feats = (model.final_proj(logits[0]).unsqueeze(0) if version == "v1" else logits)
            else: raise ValueError(translations["option_not_valid"])

            # Keep a copy of raw features for the protect blend below.
            if protect < 0.5 and pitch_guidance: feats0 = feats.clone()

            # FAISS index retrieval: mix k=8 nearest training features in.
            if (not isinstance(index, type(None)) and not isinstance(big_npy, type(None)) and index_rate != 0):
                npy = feats[0].cpu().numpy()
                if self.is_half: npy = npy.astype(np.float32)

                score, ix = index.search(npy, k=8)
                weight = np.square(1 / score)  # inverse-square-distance weights

                npy = np.sum(big_npy[ix] * np.expand_dims(weight / weight.sum(axis=1, keepdims=True), axis=2), axis=1)
                if self.is_half: npy = npy.astype(np.float16)

                feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)

            # Upsample features 2x in time to match the synthesizer frame rate.
            feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
            if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)

            p_len = audio0.shape[0] // self.window

            # Trim pitch tensors if the feature sequence came up short.
            if feats.shape[1] < p_len:
                p_len = feats.shape[1]
                if pitch_guidance:
                    pitch = pitch[:, :p_len]
                    pitchf = pitchf[:, :p_len]

            # Protect blend: keep original features where f0 says unvoiced.
            if protect < 0.5 and pitch_guidance:
                pitchff = pitchf.clone()
                pitchff[pitchf > 0] = 1
                pitchff[pitchf < 1] = protect
                pitchff = pitchff.unsqueeze(-1)

                feats = (feats * pitchff + feats0 * (1 - pitchff)).to(feats0.dtype)

            p_len = torch.tensor([p_len], device=self.device).long()
            # Synthesize: torch .pth checkpoint vs ONNX session take different call forms.
            audio1 = ((net_g.infer(feats.half() if self.is_half else feats.float(), p_len, pitch if pitch_guidance else None, (pitchf.half() if self.is_half else pitchf.float()) if pitch_guidance else None, sid)[0][0, 0]).data.cpu().float().numpy()) if self.suffix == ".pth" else (net_g.run([net_g.get_outputs()[0].name], ({net_g.get_inputs()[0].name: feats.cpu().numpy().astype(np.float32), net_g.get_inputs()[1].name: p_len.cpu().numpy(), net_g.get_inputs()[2].name: np.array([sid.cpu().item()], dtype=np.int64), net_g.get_inputs()[3].name: np.random.randn(1, 192, p_len).astype(np.float32), net_g.get_inputs()[4].name: pitch.cpu().numpy().astype(np.int64), net_g.get_inputs()[5].name: pitchf.cpu().numpy().astype(np.float32)} if pitch_guidance else {net_g.get_inputs()[0].name: feats.cpu().numpy().astype(np.float32), net_g.get_inputs()[1].name: p_len.cpu().numpy(), net_g.get_inputs()[2].name: np.array([sid.cpu().item()], dtype=np.int64), net_g.get_inputs()[3].name: np.random.randn(1, 192, p_len).astype(np.float32)}))[0][0, 0])

        if self.embed_suffix == ".pt": del padding_mask
        del feats, p_len, net_g
        clear_gpu_cache()
        return audio1

    def pipeline(self, model, net_g, sid, audio, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength, suffix, embed_suffix, f0_file=None, f0_onnx=False, pbar=None):
        """Full conversion of one waveform: load FAISS index, filter, split long
        audio at quiet points, extract f0, convert each chunk, and post-process.

        NOTE(review): ``pbar`` is dereferenced unconditionally — callers must
        pass a tqdm-like object despite the None default.
        """
        self.suffix = suffix
        self.embed_suffix = embed_suffix

        # Load the FAISS retrieval index (and its raw vectors) if requested.
        if file_index != "" and os.path.exists(file_index) and index_rate != 0:
            try:
                index = faiss.read_index(file_index)
                big_npy = index.reconstruct_n(0, index.ntotal)
            except Exception as e:
                logger.error(translations["read_faiss_index_error"].format(e=e))
                index = big_npy = None
        else: index = big_npy = None

        pbar.update(1)
        opt_ts, audio_opt = [], []
        # High-pass filter (module-level bh/ah) to strip rumble/DC.
        audio = signal.filtfilt(bh, ah, audio)
        audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")

        # Long audio: pick split points near local minima of a windowed sum,
        # so chunks are cut at quiet moments.
        if audio_pad.shape[0] > self.t_max:
            audio_sum = np.zeros_like(audio)
            for i in range(self.window):
                audio_sum += audio_pad[i : i - self.window]

            for t in range(self.t_center, audio.shape[0], self.t_center):
                opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])

        s = 0
        t, inp_f0 = None, None
        audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
        sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
        p_len = audio_pad.shape[0] // self.window

        # Optional f0 override file: CSV rows of "time,frequency".
        if hasattr(f0_file, "name"):
            try:
                with open(f0_file.name, "r") as f:
                    raw_lines = f.read()
                    if len(raw_lines) > 0:
                        inp_f0 = []
                        for line in raw_lines.strip("\n").split("\n"):
                            inp_f0.append([float(i) for i in line.split(",")])

                        inp_f0 = np.array(inp_f0, dtype=np.float32)
            except:
                logger.error(translations["error_readfile"])
                inp_f0 = None

        pbar.update(1)
        if pitch_guidance:
            pitch, pitchf = self.get_f0(audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength, inp_f0, onnx_mode=f0_onnx)
            pitch, pitchf = pitch[:p_len], pitchf[:p_len]
            # MPS backend has no float64 support.
            if self.device == "mps": pitchf = pitchf.astype(np.float32)
            pitch, pitchf = torch.tensor(pitch, device=self.device).unsqueeze(0).long(), torch.tensor(pitchf, device=self.device).unsqueeze(0).float()

        pbar.update(1)
        # Convert each chunk; the reflect padding is trimmed from every output.
        for t in opt_ts:
            t = t // self.window * self.window
            audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
            s = t

        # Final (or only) chunk: t is None when no split points were found,
        # in which case audio_pad[None:] is the whole padded signal.
        audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None, (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
        audio_opt = np.concatenate(audio_opt)
        # Optional RMS envelope blend back toward the source loudness.
        if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, self.sample_rate, volume_envelope)
        audio_max = np.abs(audio_opt).max() / 0.99
        if audio_max > 1: audio_opt /= audio_max

        if pitch_guidance: del pitch, pitchf
        del sid
        clear_gpu_cache()
        pbar.update(1)

        return audio_opt
459
-
460
- class VoiceConverter:
461
    def __init__(self, model_path, sid = 0):
        """Prepare a converter for the given model checkpoint and speaker id.

        Loads the synthesizer immediately via ``self.get_vc`` (defined later
        in this class); the embedder is loaded lazily on first conversion.
        """
        self.config = config          # global Config instance (device, precision, padding)
        self.device = config.device
        self.hubert_model = None      # embedder, loaded lazily in convert_audio
        self.tgt_sr = None            # target sample rate, set when the model is loaded
        self.net_g = None             # synthesizer network / ONNX session
        self.vc = None                # VC pipeline instance
        self.cpt = None               # raw checkpoint dict
        self.version = None           # model version ("v1"/"v2")
        self.n_spk = None             # number of speakers in the model
        self.use_f0 = None            # whether the model uses pitch guidance
        self.loaded_model = None
        self.vocoder = "Default"
        self.checkpointing = False    # memory-efficient (gradient-checkpoint) mode
        self.sample_rate = 16000      # input processing rate (Hz)
        self.sid = sid                # speaker id
        self.get_vc(model_path, sid)
478
-
479
- def convert_audio(self, audio_input_path, audio_output_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, resample_sr = 0, checkpointing = False, f0_file = None, f0_onnx = False, embedders_mode = "fairseq", formant_shifting = False, formant_qfrency = 0.8, formant_timbre = 0.8, split_audio = False):
480
- try:
481
- with tqdm(total=10, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
482
- audio = load_audio(logger, audio_input_path, self.sample_rate, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre)
483
- self.checkpointing = checkpointing
484
- audio_max = np.abs(audio).max() / 0.95
485
- if audio_max > 1: audio /= audio_max
486
-
487
- pbar.update(1)
488
- if not self.hubert_model:
489
- models, _, embed_suffix = load_embedders_model(embedder_model, embedders_mode, providers=get_providers())
490
- self.hubert_model = (models.to(self.device).half() if self.config.is_half else models.to(self.device).float()).eval() if embed_suffix in [".pt", ".safetensors"] else models
491
- self.embed_suffix = embed_suffix
492
-
493
- pbar.update(1)
494
- if self.tgt_sr != resample_sr >= self.sample_rate: self.tgt_sr = resample_sr
495
- target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - self.tgt_sr))
496
-
497
- if split_audio:
498
- chunks = cut(audio, self.sample_rate, db_thresh=-60, min_interval=500)
499
- pbar.total = len(chunks) * 4 + 6
500
- logger.info(f"{translations['split_total']}: {len(chunks)}")
501
- else: chunks = [(audio, 0, 0)]
502
-
503
- converted_chunks = []
504
- pbar.update(1)
505
-
506
- for waveform, start, end in chunks:
507
- converted_chunks.append((start, end, self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=self.sid, audio=waveform, pitch=pitch, f0_method=f0_method, file_index=(index_path.strip().strip('"').strip("\n").strip('"').strip().replace("trained", "added")), index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, suffix=self.suffix, embed_suffix=self.embed_suffix, f0_file=f0_file, f0_onnx=f0_onnx, pbar=pbar)))
508
-
509
- pbar.update(1)
510
- audio_output = restore(converted_chunks, total_len=len(audio), dtype=converted_chunks[0][2].dtype) if split_audio else converted_chunks[0][2]
511
- if target_sr >= self.sample_rate and self.tgt_sr != target_sr: audio_output = librosa.resample(audio_output, orig_sr=self.tgt_sr, target_sr=target_sr, res_type="soxr_vhq")
512
-
513
- pbar.update(1)
514
- if clean_audio:
515
- from main.tools.noisereduce import reduce_noise
516
- audio_output = reduce_noise(y=audio_output, sr=target_sr, prop_decrease=clean_strength, device=self.device)
517
-
518
- sf.write(audio_output_path, audio_output, target_sr, format=export_format)
519
- pbar.update(1)
520
- except Exception as e:
521
- logger.error(translations["error_convert"].format(e=e))
522
- import traceback
523
- logger.debug(traceback.format_exc())
524
-
525
- def get_vc(self, weight_root, sid):
526
- if sid == "" or sid == []:
527
- self.cleanup()
528
- clear_gpu_cache()
529
-
530
- if not self.loaded_model or self.loaded_model != weight_root:
531
- self.loaded_model = weight_root
532
- self.load_model()
533
- if self.cpt is not None: self.setup()
534
-
535
- def cleanup(self):
536
- if self.hubert_model is not None:
537
- del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
538
- self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
539
- clear_gpu_cache()
540
-
541
- del self.net_g, self.cpt
542
- clear_gpu_cache()
543
- self.cpt = None
544
-
545
- def load_model(self):
546
- if os.path.isfile(self.loaded_model):
547
- if self.loaded_model.endswith(".pth"): self.cpt = torch.load(self.loaded_model, map_location="cpu")
548
- else:
549
- sess_options = onnxruntime.SessionOptions()
550
- sess_options.log_severity_level = 3
551
- self.cpt = onnxruntime.InferenceSession(self.loaded_model, sess_options=sess_options, providers=get_providers())
552
- else: self.cpt = None
553
-
554
- def setup(self):
555
- if self.cpt is not None:
556
- if self.loaded_model.endswith(".pth"):
557
- self.tgt_sr = self.cpt["config"][-1]
558
- self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
559
- self.use_f0 = self.cpt.get("f0", 1)
560
- self.version = self.cpt.get("version", "v1")
561
- self.vocoder = self.cpt.get("vocoder", "Default")
562
- if self.vocoder != "Default": self.config.is_half = False
563
-
564
- self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=768 if self.version == "v2" else 256, vocoder=self.vocoder, checkpointing=self.checkpointing)
565
- del self.net_g.enc_q
566
-
567
- self.net_g.load_state_dict(self.cpt["weight"], strict=False)
568
- self.net_g.eval().to(self.device)
569
- self.net_g = (self.net_g.half() if self.config.is_half else self.net_g.float())
570
- self.n_spk = self.cpt["config"][-3]
571
- self.suffix = ".pth"
572
- else:
573
- import json
574
- import onnx
575
-
576
- metadata_dict = None
577
- for prop in onnx.load(self.loaded_model).metadata_props:
578
- if prop.key == "model_info":
579
- metadata_dict = json.loads(prop.value)
580
- break
581
-
582
- self.net_g = self.cpt
583
- self.tgt_sr = metadata_dict.get("sr", 32000)
584
- self.use_f0 = metadata_dict.get("f0", 1)
585
- self.version = metadata_dict.get("version", "v1")
586
- self.suffix = ".onnx"
587
-
588
- self.vc = VC(self.tgt_sr, self.config)
589
-
590
- if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/create_dataset.py DELETED
@@ -1,230 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import yt_dlp
5
- import shutil
6
- import librosa
7
- import logging
8
- import argparse
9
- import warnings
10
- import logging.handlers
11
-
12
- from soundfile import read, write
13
- from distutils.util import strtobool
14
-
15
- sys.path.append(os.getcwd())
16
-
17
- from main.configs.config import Config
18
- from main.library.algorithm.separator import Separator
19
-
20
- config = Config()
21
- translations = config.translations
22
- dataset_temp = os.path.join("dataset_temp")
23
- logger = logging.getLogger(__name__)
24
-
25
- if logger.hasHandlers(): logger.handlers.clear()
26
- else:
27
- console_handler = logging.StreamHandler()
28
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
29
- console_handler.setFormatter(console_formatter)
30
- console_handler.setLevel(logging.INFO)
31
- file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "create_dataset.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
32
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
33
- file_handler.setFormatter(file_formatter)
34
- file_handler.setLevel(logging.DEBUG)
35
- logger.addHandler(console_handler)
36
- logger.addHandler(file_handler)
37
- logger.setLevel(logging.DEBUG)
38
-
39
- def parse_arguments():
40
- parser = argparse.ArgumentParser()
41
- parser.add_argument("--input_audio", type=str, required=True)
42
- parser.add_argument("--output_dataset", type=str, default="./dataset")
43
- parser.add_argument("--sample_rate", type=int, default=44100)
44
- parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
45
- parser.add_argument("--clean_strength", type=float, default=0.7)
46
- parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
47
- parser.add_argument("--kim_vocal_version", type=int, default=2)
48
- parser.add_argument("--overlap", type=float, default=0.25)
49
- parser.add_argument("--segments_size", type=int, default=256)
50
- parser.add_argument("--mdx_hop_length", type=int, default=1024)
51
- parser.add_argument("--mdx_batch_size", type=int, default=1)
52
- parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
53
- parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
54
- parser.add_argument("--skip_start_audios", type=str, default="0")
55
- parser.add_argument("--skip_end_audios", type=str, default="0")
56
-
57
- return parser.parse_args()
58
-
59
- def main():
60
- pid_path = os.path.join("assets", "create_dataset_pid.txt")
61
- with open(pid_path, "w") as pid_file:
62
- pid_file.write(str(os.getpid()))
63
-
64
- args = parse_arguments()
65
- input_audio, output_dataset, sample_rate, clean_dataset, clean_strength, separator_reverb, kim_vocal_version, overlap, segments_size, hop_length, batch_size, denoise_mdx, skip, skip_start_audios, skip_end_audios = args.input_audio, args.output_dataset, args.sample_rate, args.clean_dataset, args.clean_strength, args.separator_reverb, args.kim_vocal_version, args.overlap, args.segments_size, args.mdx_hop_length, args.mdx_batch_size, args.denoise_mdx, args.skip, args.skip_start_audios, args.skip_end_audios
66
- log_data = {translations['audio_path']: input_audio, translations['output_path']: output_dataset, translations['sr']: sample_rate, translations['clear_dataset']: clean_dataset, translations['dereveb_audio']: separator_reverb, translations['segments_size']: segments_size, translations['overlap']: overlap, "Hop length": hop_length, translations['batch_size']: batch_size, translations['denoise_mdx']: denoise_mdx, translations['skip']: skip}
67
-
68
- if clean_dataset: log_data[translations['clean_strength']] = clean_strength
69
- if skip:
70
- log_data[translations['skip_start']] = skip_start_audios
71
- log_data[translations['skip_end']] = skip_end_audios
72
-
73
- for key, value in log_data.items():
74
- logger.debug(f"{key}: {value}")
75
-
76
- if kim_vocal_version not in [1, 2]: raise ValueError(translations["version_not_valid"])
77
- start_time = time.time()
78
-
79
- try:
80
- paths = []
81
-
82
- if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
83
- urls = input_audio.replace(", ", ",").split(",")
84
-
85
- for url in urls:
86
- path = downloader(url, urls.index(url))
87
- paths.append(path)
88
-
89
- if skip:
90
- skip_start_audios, skip_end_audios = skip_start_audios.replace(", ", ",").split(","), skip_end_audios.replace(", ", ",").split(",")
91
-
92
- if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
93
- logger.warning(translations["skip<audio"])
94
- sys.exit(1)
95
- elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
96
- logger.warning(translations["skip>audio"])
97
- sys.exit(1)
98
- else:
99
- for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
100
- skip_start(audio, skip_start_audio)
101
- skip_end(audio, skip_end_audio)
102
-
103
- separator_paths = []
104
-
105
- for audio in paths:
106
- vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size, sample_rate)
107
- if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size, sample_rate)
108
- separator_paths.append(vocals)
109
-
110
- paths = separator_paths
111
-
112
- for audio_path in paths:
113
- data, sample_rate = read(audio_path)
114
- data = librosa.to_mono(data.T)
115
-
116
- if clean_dataset:
117
- from main.tools.noisereduce import reduce_noise
118
- data = reduce_noise(y=data, sr=sample_rate, prop_decrease=clean_strength, device=config.device)
119
-
120
- write(audio_path, data, sample_rate)
121
- except Exception as e:
122
- logger.error(f"{translations['create_dataset_error']}: {e}")
123
- import traceback
124
- logger.error(traceback.format_exc())
125
- finally:
126
- for audio in paths:
127
- shutil.move(audio, output_dataset)
128
-
129
- if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)
130
-
131
- elapsed_time = time.time() - start_time
132
- if os.path.exists(pid_path): os.remove(pid_path)
133
- logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
134
-
135
- def downloader(url, name):
136
- with warnings.catch_warnings():
137
- warnings.simplefilter("ignore")
138
-
139
- ydl_opts = {"format": "bestaudio/best", "outtmpl": os.path.join(dataset_temp, f"{name}"), "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "no_warnings": True, "noplaylist": True, "noplaylist": True, "verbose": False}
140
- logger.info(f"{translations['starting_download']}: {url}...")
141
-
142
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
143
- ydl.extract_info(url)
144
- logger.info(f"{translations['download_success']}: {url}")
145
-
146
- return os.path.join(dataset_temp, f"{name}" + ".wav")
147
-
148
- def skip_start(input_file, seconds):
149
- data, sr = read(input_file)
150
- total_duration = len(data) / sr
151
-
152
- if seconds <= 0: logger.warning(translations["=<0"])
153
- elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
154
- else:
155
- logger.info(f"{translations['skip_start']}: {input_file}...")
156
- write(input_file, data[int(seconds * sr):], sr)
157
-
158
- logger.info(translations["skip_start_audio"].format(input_file=input_file))
159
-
160
- def skip_end(input_file, seconds):
161
- data, sr = read(input_file)
162
- total_duration = len(data) / sr
163
-
164
- if seconds <= 0: logger.warning(translations["=<0"])
165
- elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
166
- else:
167
- logger.info(f"{translations['skip_end']}: {input_file}...")
168
- write(input_file, data[:-int(seconds * sr)], sr)
169
-
170
- logger.info(translations["skip_end_audio"].format(input_file=input_file))
171
-
172
- def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size, sample_rate):
173
- if not os.path.exists(input):
174
- logger.warning(translations["input_not_valid"])
175
- return None
176
-
177
- if not os.path.exists(output):
178
- logger.warning(translations["output_not_valid"])
179
- return None
180
-
181
- model = f"Kim_Vocal_{version}.onnx"
182
- output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
183
-
184
- for f in output_separator:
185
- path = os.path.join(output, f)
186
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
187
-
188
- if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
189
- elif '_(Vocals)_' in f:
190
- rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
191
- os.rename(path, rename_file)
192
-
193
- return rename_file
194
-
195
- def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size, sample_rate):
196
- if not os.path.exists(input):
197
- logger.warning(translations["input_not_valid"])
198
- return None
199
-
200
- if not os.path.exists(output):
201
- logger.warning(translations["output_not_valid"])
202
- return None
203
-
204
- logger.info(f"{translations['dereverb']}: {input}...")
205
- output_dereverb = separator_main(audio_file=input, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=hop_length, mdx_hop_length=batch_size, mdx_enable_denoise=denoise, sample_rate=sample_rate)
206
-
207
- for f in output_dereverb:
208
- path = os.path.join(output, f)
209
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
210
-
211
- if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
212
- elif '_(No Reverb)_' in f:
213
- rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
214
- os.rename(path, rename_file)
215
-
216
- logger.info(f"{translations['dereverb_success']}: {rename_file}")
217
- return rename_file
218
-
219
- def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, sample_rate=44100):
220
- try:
221
- separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise})
222
- separator.load_model(model_filename=model_filename)
223
- return separator.separate(audio_file)
224
- except:
225
- logger.debug(translations["default_setting"])
226
- separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise})
227
- separator.load_model(model_filename=model_filename)
228
- return separator.separate(audio_file)
229
-
230
- if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/create_index.py DELETED
@@ -1,90 +0,0 @@
1
- import os
2
- import sys
3
- import faiss
4
- import logging
5
- import argparse
6
- import logging.handlers
7
-
8
- import numpy as np
9
-
10
- from multiprocessing import cpu_count
11
- from sklearn.cluster import MiniBatchKMeans
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from main.configs.config import Config
16
- translations = Config().translations
17
-
18
- def parse_arguments():
19
- parser = argparse.ArgumentParser()
20
- parser.add_argument("--model_name", type=str, required=True)
21
- parser.add_argument("--rvc_version", type=str, default="v2")
22
- parser.add_argument("--index_algorithm", type=str, default="Auto")
23
-
24
- return parser.parse_args()
25
-
26
- def main():
27
- args = parse_arguments()
28
- exp_dir = os.path.join("assets", "logs", args.model_name)
29
- version, index_algorithm = args.rvc_version, args.index_algorithm
30
- logger = logging.getLogger(__name__)
31
-
32
- if logger.hasHandlers(): logger.handlers.clear()
33
- else:
34
- console_handler = logging.StreamHandler()
35
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
36
- console_handler.setFormatter(console_formatter)
37
- console_handler.setLevel(logging.INFO)
38
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "create_index.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
39
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
40
- file_handler.setFormatter(file_formatter)
41
- file_handler.setLevel(logging.DEBUG)
42
- logger.addHandler(console_handler)
43
- logger.addHandler(file_handler)
44
- logger.setLevel(logging.DEBUG)
45
-
46
- log_data = {translations['modelname']: args.model_name, translations['model_path']: exp_dir, translations['training_version']: version, translations['index_algorithm_info']: index_algorithm}
47
- for key, value in log_data.items():
48
- logger.debug(f"{key}: {value}")
49
-
50
- try:
51
- npys = []
52
- feature_dir = os.path.join(exp_dir, f"{version}_extracted")
53
- model_name = os.path.basename(exp_dir)
54
-
55
- for name in sorted(os.listdir(feature_dir)):
56
- npys.append(np.load(os.path.join(feature_dir, name)))
57
-
58
- big_npy = np.concatenate(npys, axis=0)
59
- big_npy_idx = np.arange(big_npy.shape[0])
60
- np.random.shuffle(big_npy_idx)
61
- big_npy = big_npy[big_npy_idx]
62
-
63
- if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)
64
- np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)
65
-
66
- n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
67
- index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
68
- index_ivf_trained = faiss.extract_index_ivf(index_trained)
69
- index_ivf_trained.nprobe = 1
70
- index_trained.train(big_npy)
71
- faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))
72
-
73
- index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
74
- index_ivf_added = faiss.extract_index_ivf(index_added)
75
- index_ivf_added.nprobe = 1
76
- index_added.train(big_npy)
77
- batch_size_add = 8192
78
-
79
- for i in range(0, big_npy.shape[0], batch_size_add):
80
- index_added.add(big_npy[i : i + batch_size_add])
81
-
82
- index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")
83
- faiss.write_index(index_added, index_filepath_added)
84
- logger.info(f"{translations['save_index']} '{index_filepath_added}'")
85
- except Exception as e:
86
- logger.error(f"{translations['create_index_error']}: {e}")
87
- import traceback
88
- logger.debug(traceback.format_exc())
89
-
90
- if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/extract.py DELETED
@@ -1,360 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import time
5
- import tqdm
6
- import torch
7
- import shutil
8
- import logging
9
- import argparse
10
- import warnings
11
- import onnxruntime
12
- import logging.handlers
13
-
14
- import numpy as np
15
- import soundfile as sf
16
- import torch.nn.functional as F
17
-
18
- from random import shuffle
19
- from distutils.util import strtobool
20
- from concurrent.futures import ThreadPoolExecutor, as_completed
21
-
22
- sys.path.append(os.getcwd())
23
-
24
- from main.configs.config import Config
25
- from main.library.utils import check_predictors, check_embedders, load_audio, load_embedders_model
26
-
27
- logger = logging.getLogger(__name__)
28
- config = Config()
29
- translations = config.translations
30
- logger.propagate = False
31
-
32
- warnings.filterwarnings("ignore")
33
- for l in ["torch", "faiss", "httpx", "httpcore", "faiss.loader", "numba.core", "urllib3", "matplotlib"]:
34
- logging.getLogger(l).setLevel(logging.ERROR)
35
-
36
- def parse_arguments():
37
- parser = argparse.ArgumentParser()
38
- parser.add_argument("--model_name", type=str, required=True)
39
- parser.add_argument("--rvc_version", type=str, default="v2")
40
- parser.add_argument("--f0_method", type=str, default="rmvpe")
41
- parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
42
- parser.add_argument("--hop_length", type=int, default=128)
43
- parser.add_argument("--cpu_cores", type=int, default=2)
44
- parser.add_argument("--gpu", type=str, default="-")
45
- parser.add_argument("--sample_rate", type=int, required=True)
46
- parser.add_argument("--embedder_model", type=str, default="contentvec_base")
47
- parser.add_argument("--f0_onnx", type=lambda x: bool(strtobool(x)), default=False)
48
- parser.add_argument("--embedders_mode", type=str, default="fairseq")
49
-
50
- return parser.parse_args()
51
-
52
- def generate_config(rvc_version, sample_rate, model_path):
53
- config_save_path = os.path.join(model_path, "config.json")
54
- if not os.path.exists(config_save_path): shutil.copy(os.path.join("main", "configs", rvc_version, f"{sample_rate}.json"), config_save_path)
55
-
56
- def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate, embedders_mode = "fairseq"):
57
- gt_wavs_dir, feature_dir = os.path.join(model_path, "sliced_audios"), os.path.join(model_path, f"{rvc_version}_extracted")
58
- f0_dir, f0nsf_dir = None, None
59
-
60
- if pitch_guidance: f0_dir, f0nsf_dir = os.path.join(model_path, "f0"), os.path.join(model_path, "f0_voiced")
61
-
62
- gt_wavs_files, feature_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir)), set(name.split(".")[0] for name in os.listdir(feature_dir))
63
- names = gt_wavs_files & feature_files & set(name.split(".")[0] for name in os.listdir(f0_dir)) & set(name.split(".")[0] for name in os.listdir(f0nsf_dir)) if pitch_guidance else gt_wavs_files & feature_files
64
-
65
- options = []
66
- mute_base_path = os.path.join("assets", "logs", "mute" if embedders_mode != "spin" else "mute_spin")
67
-
68
- for name in names:
69
- options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0" if pitch_guidance else f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")
70
-
71
- mute_audio_path, mute_feature_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav"), os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")
72
- for _ in range(2):
73
- options.append(f"{mute_audio_path}|{mute_feature_path}|{os.path.join(mute_base_path, 'f0', 'mute.wav.npy')}|{os.path.join(mute_base_path, 'f0_voiced', 'mute.wav.npy')}|0" if pitch_guidance else f"{mute_audio_path}|{mute_feature_path}|0")
74
-
75
- shuffle(options)
76
- with open(os.path.join(model_path, "filelist.txt"), "w") as f:
77
- f.write("\n".join(options))
78
-
79
- def setup_paths(exp_dir, version = None):
80
- wav_path = os.path.join(exp_dir, "sliced_audios_16k")
81
-
82
- if version:
83
- out_path = os.path.join(exp_dir, f"{version}_extracted")
84
- os.makedirs(out_path, exist_ok=True)
85
- return wav_path, out_path
86
- else:
87
- output_root1, output_root2 = os.path.join(exp_dir, "f0"), os.path.join(exp_dir, "f0_voiced")
88
- os.makedirs(output_root1, exist_ok=True); os.makedirs(output_root2, exist_ok=True)
89
- return wav_path, output_root1, output_root2
90
-
91
- def read_wave(wav_path, normalize = False, is_half = False):
92
- wav, sr = sf.read(wav_path, dtype=np.float32)
93
- assert sr == 16000, translations["sr_not_16000"]
94
-
95
- feats = torch.from_numpy(wav)
96
- if feats.dim() == 2: feats = feats.mean(-1)
97
- feats = feats.view(1, -1)
98
-
99
- if normalize: feats = F.layer_norm(feats, feats.shape)
100
- return feats.half() if is_half else feats.float()
101
-
102
- def get_device(gpu_index):
103
- try:
104
- index = int(gpu_index)
105
- if index < torch.cuda.device_count(): return f"cuda:{index}"
106
- else: logger.warning(translations["gpu_not_valid"])
107
- except ValueError:
108
- logger.warning(translations["gpu_not_valid"])
109
- return "cpu"
110
-
111
- def get_providers():
112
- ort_providers = onnxruntime.get_available_providers()
113
-
114
- if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
115
- elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
116
- else: providers = ["CPUExecutionProvider"]
117
-
118
- return providers
119
-
120
- class FeatureInput:
121
- def __init__(self, sample_rate=16000, hop_size=160, is_half=False, device=config.device):
122
- self.fs = sample_rate
123
- self.hop = hop_size
124
- self.f0_bin = 256
125
- self.f0_max = 1100.0
126
- self.f0_min = 50.0
127
- self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
128
- self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
129
- self.device = device
130
- self.is_half = is_half
131
-
132
- def compute_f0_hybrid(self, methods_str, np_arr, hop_length, f0_onnx):
133
- methods_str = re.search("hybrid\[(.+)\]", methods_str)
134
- if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
135
- f0_computation_stack, resampled_stack = [], []
136
- logger.debug(translations["hybrid_methods"].format(methods=methods))
137
-
138
- for method in methods:
139
- f0 = None
140
- f0_methods = {"pm": lambda: self.get_pm(np_arr), "dio": lambda: self.get_pyworld(np_arr, "dio"), "mangio-crepe-full": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=f0_onnx), "mangio-crepe-large": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=f0_onnx), "mangio-crepe-medium": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=f0_onnx), "mangio-crepe-small": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=f0_onnx), "mangio-crepe-tiny": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=f0_onnx), "crepe-full": lambda: self.get_crepe(np_arr, "full", onnx=f0_onnx), "crepe-large": lambda: self.get_crepe(np_arr, "large", onnx=f0_onnx), "crepe-medium": lambda: self.get_crepe(np_arr, "medium", onnx=f0_onnx), "crepe-small": lambda: self.get_crepe(np_arr, "small", onnx=f0_onnx), "crepe-tiny": lambda: self.get_crepe(np_arr, "tiny", onnx=f0_onnx), "fcpe": lambda: self.get_fcpe(np_arr, int(hop_length), onnx=f0_onnx), "fcpe-legacy": lambda: self.get_fcpe(np_arr, int(hop_length), legacy=True, onnx=f0_onnx), "rmvpe": lambda: self.get_rmvpe(np_arr, onnx=f0_onnx), "rmvpe-legacy": lambda: self.get_rmvpe(np_arr, legacy=True, onnx=f0_onnx), "harvest": lambda: self.get_pyworld(np_arr, "harvest"), "swipe": lambda: self.get_swipe(np_arr), "yin": lambda: self.get_yin(np_arr, int(hop_length), mode="yin"), "pyin": lambda: self.get_yin(np_arr, int(hop_length), mode="pyin")}
141
- f0 = f0_methods.get(method, lambda: ValueError(translations["method_not_valid"]))()
142
- f0_computation_stack.append(f0)
143
-
144
- for f0 in f0_computation_stack:
145
- resampled_stack.append(np.interp(np.linspace(0, len(f0), (np_arr.size // self.hop)), np.arange(len(f0)), f0))
146
-
147
- return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
148
-
149
- def compute_f0(self, np_arr, f0_method, hop_length, f0_onnx=False):
150
- f0_methods = {"pm": lambda: self.get_pm(np_arr), "dio": lambda: self.get_pyworld(np_arr, "dio"), "mangio-crepe-full": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=f0_onnx), "mangio-crepe-large": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=f0_onnx), "mangio-crepe-medium": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=f0_onnx), "mangio-crepe-small": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=f0_onnx), "mangio-crepe-tiny": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=f0_onnx), "crepe-full": lambda: self.get_crepe(np_arr, "full", onnx=f0_onnx), "crepe-large": lambda: self.get_crepe(np_arr, "large", onnx=f0_onnx), "crepe-medium": lambda: self.get_crepe(np_arr, "medium", onnx=f0_onnx), "crepe-small": lambda: self.get_crepe(np_arr, "small", onnx=f0_onnx), "crepe-tiny": lambda: self.get_crepe(np_arr, "tiny", onnx=f0_onnx), "fcpe": lambda: self.get_fcpe(np_arr, int(hop_length), onnx=f0_onnx), "fcpe-legacy": lambda: self.get_fcpe(np_arr, int(hop_length), legacy=True, onnx=f0_onnx), "rmvpe": lambda: self.get_rmvpe(np_arr, onnx=f0_onnx), "rmvpe-legacy": lambda: self.get_rmvpe(np_arr, legacy=True, onnx=f0_onnx), "harvest": lambda: self.get_pyworld(np_arr, "harvest"), "swipe": lambda: self.get_swipe(np_arr), "yin": lambda: self.get_yin(np_arr, int(hop_length), mode="yin"), "pyin": lambda: self.get_yin(np_arr, int(hop_length), mode="pyin")}
151
- return self.compute_f0_hybrid(f0_method, np_arr, int(hop_length), f0_onnx) if "hybrid" in f0_method else f0_methods.get(f0_method, lambda: ValueError(translations["method_not_valid"]))()
152
-
153
- def get_pm(self, x):
154
- import parselmouth
155
-
156
- f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=(160 / 16000 * 1000) / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
157
- pad_size = ((x.size // self.hop) - len(f0) + 1) // 2
158
-
159
- if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")
160
- return f0
161
-
162
- def get_mangio_crepe(self, x, hop_length, model="full", onnx=False):
163
- from main.library.predictors.CREPE import predict
164
-
165
- audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
166
- audio /= torch.quantile(torch.abs(audio), 0.999)
167
- audio = audio.unsqueeze(0)
168
- source = predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True, providers=get_providers(), onnx=onnx).squeeze(0).cpu().float().numpy()
169
- source[source < 0.001] = np.nan
170
-
171
- return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
172
-
173
- def get_crepe(self, x, model="full", onnx=False):
174
- from main.library.predictors.CREPE import predict, mean, median
175
-
176
- f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.fs, 160, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=get_providers(), onnx=onnx)
177
- f0, pd = mean(f0, 3), median(pd, 3)
178
- f0[pd < 0.1] = 0
179
-
180
- return f0[0].cpu().numpy()
181
-
182
- def get_fcpe(self, x, hop_length, legacy=False, onnx=False):
183
- from main.library.predictors.FCPE import FCPE
184
-
185
- model_fcpe = FCPE(os.path.join("assets", "models", "predictors", ("fcpe_legacy" if legacy else"fcpe") + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03 if legacy else 0.006, providers=get_providers(), onnx=onnx, legacy=legacy)
186
- f0 = model_fcpe.compute_f0(x, p_len=(x.size // self.hop))
187
-
188
- del model_fcpe
189
- return f0
190
-
191
- def get_rmvpe(self, x, legacy=False, onnx=False):
192
- from main.library.predictors.RMVPE import RMVPE
193
-
194
- rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), is_half=self.is_half, device=self.device, onnx=onnx, providers=get_providers())
195
- f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
196
-
197
- del rmvpe_model
198
- return f0
199
-
200
- def get_pyworld(self, x, model="harvest"):
201
- from main.library.predictors.WORLD_WRAPPER import PYWORLD
202
-
203
- pw = PYWORLD()
204
- x = x.astype(np.double)
205
-
206
- if model == "harvest": f0, t = pw.harvest(x, fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
207
- elif model == "dio": f0, t = pw.dio(x, fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
208
- else: raise ValueError(translations["method_not_valid"])
209
-
210
- return pw.stonemask(x, self.fs, t, f0)
211
-
212
- def get_swipe(self, x):
213
- from main.library.predictors.SWIPE import swipe
214
-
215
- f0, _ = swipe(x.astype(np.float32), self.fs, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=1000 * self.hop / self.fs)
216
- return f0
217
-
218
- def get_yin(self, x, hop_length, mode="yin"):
219
- import librosa
220
-
221
- source = np.array(librosa.yin(x.astype(np.float32), sr=self.fs, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length) if mode == "yin" else librosa.pyin(x.astype(np.float32), fmin=self.f0_min, fmax=self.f0_max, sr=self.fs, hop_length=hop_length)[0])
222
- source[source < 0.001] = np.nan
223
- return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
224
-
225
- def coarse_f0(self, f0):
226
- return np.rint(np.clip(((1127 * np.log(1 + f0 / 700)) - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)).astype(int)
227
-
228
- def process_file(self, file_info, f0_method, hop_length, f0_onnx):
229
- inp_path, opt_path1, opt_path2, np_arr = file_info
230
- if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return
231
-
232
- try:
233
- feature_pit = self.compute_f0(np_arr, f0_method, hop_length, f0_onnx)
234
- if isinstance(feature_pit, tuple): feature_pit = feature_pit[0]
235
- np.save(opt_path2, feature_pit, allow_pickle=False)
236
- np.save(opt_path1, self.coarse_f0(feature_pit), allow_pickle=False)
237
- except Exception as e:
238
- raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")
239
-
240
- def process_files(self, files, f0_method, hop_length, f0_onnx, pbar):
241
- for file_info in files:
242
- self.process_file(file_info, f0_method, hop_length, f0_onnx)
243
- pbar.update()
244
-
245
- def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus, f0_onnx, is_half):
246
- input_root, *output_roots = setup_paths(exp_dir)
247
- output_root1, output_root2 = output_roots if len(output_roots) == 2 else (output_roots[0], None)
248
-
249
- paths = [(os.path.join(input_root, name), os.path.join(output_root1, name) if output_root1 else None, os.path.join(output_root2, name) if output_root2 else None, load_audio(logger, os.path.join(input_root, name), 16000)) for name in sorted(os.listdir(input_root)) if "spec" not in name]
250
- logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))
251
-
252
- start_time = time.time()
253
- gpus = gpus.split("-")
254
- process_partials = []
255
-
256
- pbar = tqdm.tqdm(total=len(paths), ncols=100, unit="p")
257
- for idx, gpu in enumerate(gpus):
258
- feature_input = FeatureInput(device=get_device(gpu) if gpu != "" else "cpu", is_half=is_half)
259
- process_partials.append((feature_input, paths[idx::len(gpus)]))
260
-
261
- with ThreadPoolExecutor(max_workers=num_processes) as executor:
262
- for future in as_completed([executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, f0_onnx, pbar) for feature_input, part_paths in process_partials]):
263
- pbar.update(1)
264
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
265
- future.result()
266
-
267
- pbar.close()
268
- logger.info(translations["extract_f0_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))
269
-
270
- def extract_features(model, feats, version):
271
- return torch.as_tensor(model.run([model.get_outputs()[0].name, model.get_outputs()[1].name], {"feats": feats.detach().cpu().numpy()})[0 if version == "v1" else 1], dtype=torch.float32, device=feats.device)
272
-
273
- def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg, embed_suffix, is_half):
274
- out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
275
- if os.path.exists(out_file_path): return
276
- feats = read_wave(os.path.join(wav_path, file), normalize=saved_cfg.task.normalize if saved_cfg else False, is_half=is_half).to(device)
277
-
278
- with torch.no_grad():
279
- if embed_suffix == ".pt":
280
- model = model.to(device).to(torch.float16 if is_half else torch.float32).eval()
281
- logits = model.extract_features(**{"source": feats, "padding_mask": torch.BoolTensor(feats.shape).fill_(False).to(device), "output_layer": 9 if version == "v1" else 12})
282
- feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
283
- elif embed_suffix == ".onnx": feats = extract_features(model, feats, version).to(device)
284
- elif embed_suffix == ".safetensors":
285
- model = model.to(device).to(torch.float16 if is_half else torch.float32).eval()
286
- logits = model(feats)["last_hidden_state"]
287
- feats = (model.final_proj(logits[0]).unsqueeze(0) if version == "v1" else logits)
288
- else: raise ValueError(translations["option_not_valid"])
289
-
290
- feats = feats.squeeze(0).float().cpu().numpy()
291
- if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
292
- else: logger.warning(f"{file} {translations['NaN']}")
293
-
294
- def run_embedding_extraction(exp_dir, version, gpus, embedder_model, embedders_mode, is_half):
295
- wav_path, out_path = setup_paths(exp_dir, version)
296
- logger.info(translations["start_extract_hubert"])
297
- start_time = time.time()
298
- models, saved_cfg, embed_suffix = load_embedders_model(embedder_model, embedders_mode, providers=get_providers())
299
- devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]
300
- paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])
301
-
302
- if not paths:
303
- logger.warning(translations["not_found_audio_file"])
304
- sys.exit(1)
305
-
306
- pbar = tqdm.tqdm(total=len(paths) * len(devices), ncols=100, unit="p")
307
- for task in [(file, wav_path, out_path, models, device, version, saved_cfg, embed_suffix, is_half) for file in paths for device in devices]:
308
- try:
309
- process_file_embedding(*task)
310
- except Exception as e:
311
- raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")
312
-
313
- pbar.update(1)
314
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
315
-
316
- pbar.close()
317
- logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))
318
-
319
- def main():
320
- args = parse_arguments()
321
- exp_dir = os.path.join("assets", "logs", args.model_name)
322
- f0_method, hop_length, num_processes, gpus, version, pitch_guidance, sample_rate, embedder_model, f0_onnx, embedders_mode = args.f0_method, args.hop_length, args.cpu_cores, args.gpu, args.rvc_version, args.pitch_guidance, args.sample_rate, args.embedder_model, args.f0_onnx, args.embedders_mode
323
-
324
- check_predictors(f0_method, f0_onnx); check_embedders(embedder_model, embedders_mode)
325
- if logger.hasHandlers(): logger.handlers.clear()
326
- else:
327
- console_handler = logging.StreamHandler()
328
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
329
- console_handler.setFormatter(console_formatter)
330
- console_handler.setLevel(logging.INFO)
331
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "extract.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
332
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
333
- file_handler.setFormatter(file_formatter)
334
- file_handler.setLevel(logging.DEBUG)
335
- logger.addHandler(console_handler)
336
- logger.addHandler(file_handler)
337
- logger.setLevel(logging.DEBUG)
338
-
339
- log_data = {translations['modelname']: args.model_name, translations['export_process']: exp_dir, translations['f0_method']: f0_method, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, "Gpu": gpus, "Hop length": hop_length, translations['training_version']: version, translations['extract_f0']: pitch_guidance, translations['hubert_model']: embedder_model, translations["f0_onnx_mode"]: f0_onnx, translations["embed_mode"]: embedders_mode}
340
- for key, value in log_data.items():
341
- logger.debug(f"{key}: {value}")
342
-
343
- pid_path = os.path.join(exp_dir, "extract_pid.txt")
344
- with open(pid_path, "w") as pid_file:
345
- pid_file.write(str(os.getpid()))
346
-
347
- try:
348
- run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus, f0_onnx, config.is_half)
349
- run_embedding_extraction(exp_dir, version, gpus, embedder_model, embedders_mode, config.is_half)
350
- generate_config(version, sample_rate, exp_dir)
351
- generate_filelist(pitch_guidance, exp_dir, version, sample_rate, embedders_mode)
352
- except Exception as e:
353
- logger.error(f"{translations['extract_error']}: {e}")
354
- import traceback
355
- logger.debug(traceback.format_exc())
356
-
357
- if os.path.exists(pid_path): os.remove(pid_path)
358
- logger.info(f"{translations['extract_success']} {args.model_name}.")
359
-
360
- if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/preprocess.py DELETED
@@ -1,270 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import logging
5
- import librosa
6
- import argparse
7
- import logging.handlers
8
-
9
- import numpy as np
10
- import torch.multiprocessing as mp
11
-
12
- from tqdm import tqdm
13
- from scipy import signal
14
- from scipy.io import wavfile
15
- from distutils.util import strtobool
16
- from concurrent.futures import ProcessPoolExecutor, as_completed
17
-
18
- sys.path.append(os.getcwd())
19
-
20
- from main.library.utils import load_audio
21
- from main.configs.config import Config
22
-
23
- logger = logging.getLogger(__name__)
24
- for l in ["numba.core.byteflow", "numba.core.ssa", "numba.core.interpreter"]:
25
- logging.getLogger(l).setLevel(logging.ERROR)
26
-
27
- OVERLAP, MAX_AMPLITUDE, ALPHA, HIGH_PASS_CUTOFF, SAMPLE_RATE_16K = 0.3, 0.9, 0.75, 48, 16000
28
-
29
- config = Config()
30
- translations = config.translations
31
-
32
- def parse_arguments():
33
- parser = argparse.ArgumentParser()
34
- parser.add_argument("--model_name", type=str, required=True)
35
- parser.add_argument("--dataset_path", type=str, default="./dataset")
36
- parser.add_argument("--sample_rate", type=int, required=True)
37
- parser.add_argument("--cpu_cores", type=int, default=2)
38
- parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
39
- parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
40
- parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
41
- parser.add_argument("--clean_strength", type=float, default=0.7)
42
-
43
- return parser.parse_args()
44
-
45
- class Slicer:
46
- def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
47
- if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
48
- if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])
49
- min_interval = sr * min_interval / 1000
50
- self.threshold = 10 ** (threshold / 20.0)
51
- self.hop_size = round(sr * hop_size / 1000)
52
- self.win_size = min(round(min_interval), 4 * self.hop_size)
53
- self.min_length = round(sr * min_length / 1000 / self.hop_size)
54
- self.min_interval = round(min_interval / self.hop_size)
55
- self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
56
-
57
- def _apply_slice(self, waveform, begin, end):
58
- start_idx = begin * self.hop_size
59
-
60
- if len(waveform.shape) > 1: return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)]
61
- else: return waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]
62
-
63
- def slice(self, waveform):
64
- samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
65
- if samples.shape[0] <= self.min_length: return [waveform]
66
- rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
67
- sil_tags = []
68
- silence_start, clip_start = None, 0
69
-
70
- for i, rms in enumerate(rms_list):
71
- if rms < self.threshold:
72
- if silence_start is None: silence_start = i
73
- continue
74
-
75
- if silence_start is None: continue
76
-
77
- is_leading_silence = silence_start == 0 and i > self.max_sil_kept
78
- need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
79
-
80
- if not is_leading_silence and not need_slice_middle:
81
- silence_start = None
82
- continue
83
-
84
- if i - silence_start <= self.max_sil_kept:
85
- pos = rms_list[silence_start : i + 1].argmin() + silence_start
86
- sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
87
- clip_start = pos
88
- elif i - silence_start <= self.max_sil_kept * 2:
89
- pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
90
- pos += i - self.max_sil_kept
91
- pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
92
-
93
- if silence_start == 0:
94
- sil_tags.append((0, pos_r))
95
- clip_start = pos_r
96
- else:
97
- sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
98
- clip_start = max(pos_r, pos)
99
- else:
100
- pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
101
- sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
102
- clip_start = pos_r
103
-
104
- silence_start = None
105
- total_frames = rms_list.shape[0]
106
- if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
107
-
108
- if not sil_tags: return [waveform]
109
- else:
110
- chunks = []
111
- if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
112
-
113
- for i in range(len(sil_tags) - 1):
114
- chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
115
-
116
- if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
117
- return chunks
118
-
119
- def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
120
- y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
121
- axis = -1
122
- x_shape_trimmed = list(y.shape)
123
- x_shape_trimmed[axis] -= frame_length - 1
124
- xw = np.moveaxis(np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]])), -1, axis - 1 if axis < 0 else axis + 1)
125
- slices = [slice(None)] * xw.ndim
126
- slices[axis] = slice(0, None, hop_length)
127
- return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))
128
-
129
- class PreProcess:
130
- def __init__(self, sr, exp_dir, per):
131
- self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
132
- self.sr = sr
133
- self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
134
- self.per = per
135
- self.exp_dir = exp_dir
136
- self.device = "cpu"
137
- self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
138
- self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
139
- os.makedirs(self.gt_wavs_dir, exist_ok=True)
140
- os.makedirs(self.wavs16k_dir, exist_ok=True)
141
-
142
- def _normalize_audio(self, audio):
143
- tmp_max = np.abs(audio).max()
144
- if tmp_max > 2.5: return None
145
- return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
146
-
147
- def process_audio_segment(self, normalized_audio, sid, idx0, idx1):
148
- if normalized_audio is None:
149
- logger.debug(f"{sid}-{idx0}-{idx1}-filtered")
150
- return
151
-
152
- wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
153
- wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K, res_type="soxr_vhq").astype(np.float32))
154
-
155
- def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
156
- try:
157
- audio = load_audio(logger, path, self.sr)
158
-
159
- if process_effects:
160
- audio = signal.lfilter(self.b_high, self.a_high, audio)
161
- audio = self._normalize_audio(audio)
162
-
163
- if clean_dataset:
164
- from main.tools.noisereduce import reduce_noise
165
- audio = reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength, device=config.device)
166
-
167
- idx1 = 0
168
- if cut_preprocess:
169
- for audio_segment in self.slicer.slice(audio):
170
- i = 0
171
-
172
- while 1:
173
- start = int(self.sr * (self.per - OVERLAP) * i)
174
- i += 1
175
-
176
- if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
177
- self.process_audio_segment(audio_segment[start : start + int(self.per * self.sr)], sid, idx0, idx1)
178
- idx1 += 1
179
- else:
180
- self.process_audio_segment(audio_segment[start:], sid, idx0, idx1)
181
- idx1 += 1
182
- break
183
- else: self.process_audio_segment(audio, sid, idx0, idx1)
184
- except Exception as e:
185
- raise RuntimeError(f"{translations['process_audio_error']}: {e}")
186
-
187
- def process_file(args):
188
- pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
189
- file_path, idx0, sid = file
190
- pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)
191
-
192
def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
    """Walk input_root, collect audio files and fan preprocessing out over a
    process pool of num_processes workers.

    Files directly under input_root get speaker id 0; files in a subdirectory
    use the directory name (which must parse as an integer) as the speaker id.
    Raises ValueError for a non-integer subdirectory name and RuntimeError if
    any worker fails.
    """
    start_time = time.time()
    pp = PreProcess(sr, exp_dir, per)
    logger.info(translations["start_preprocess"].format(num_processes=num_processes))
    files = []
    idx = 0  # global running file index across all speakers

    for root, _, filenames in os.walk(input_root):
        try:
            # Top-level files belong to speaker 0; subfolders name the speaker.
            sid = 0 if root == input_root else int(os.path.basename(root))

            for f in filenames:
                if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3")):
                    files.append((os.path.join(root, f), idx, sid))
                    idx += 1
        except ValueError:
            raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")

    with tqdm(total=len(files), ncols=100, unit="f") as pbar:
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            # NOTE(review): `pp` is pickled into every worker; assumes
            # PreProcess instances are picklable — confirm.
            futures = [executor.submit(process_file, (pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength)) for file in files]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    # Any single failed file aborts the whole preprocessing run.
                    raise RuntimeError(f"{translations['process_error']}: {e}")
                pbar.update(1)
                logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))

    elapsed_time = time.time() - start_time
    logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
223
-
224
def main():
    """CLI entry point: configure logging, record a PID file, and run dataset
    preprocessing for the model named on the command line."""
    args = parse_arguments()
    experiment_directory = os.path.join("assets", "logs", args.model_name)

    num_processes = args.cpu_cores
    num_processes = 2 if num_processes is None else int(num_processes)
    dataset, sample_rate, cut_preprocess, preprocess_effects, clean_dataset, clean_strength = args.dataset_path, args.sample_rate, args.cut_preprocess, args.process_effects, args.clean_dataset, args.clean_strength

    os.makedirs(experiment_directory, exist_ok=True)

    # BUG FIX: the original cleared existing handlers and then skipped handler
    # installation entirely (setup lived in the `else` branch), so a re-entered
    # logger ended up with no handlers at all. Clear, then always (re)install.
    if logger.hasHandlers(): logger.handlers.clear()
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)
    file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_directory, "preprocess.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

    # Dump the effective configuration to the debug log.
    log_data = {translations['modelname']: args.model_name, translations['export_process']: experiment_directory, translations['dataset_folder']: dataset, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, translations['split_audio']: cut_preprocess, translations['preprocess_effect']: preprocess_effects, translations['clear_audio']: clean_dataset}
    if clean_dataset: log_data[translations['clean_strength']] = clean_strength

    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    # PID file lets the UI find and stop a running preprocessing job.
    pid_path = os.path.join(experiment_directory, "preprocess_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, config.per_preprocess, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
    except Exception as e:
        # Log and continue to cleanup — the PID file must be removed either way.
        logger.error(f"{translations['process_audio_error']} {e}")
        import traceback
        logger.debug(traceback.format_exc())

    if os.path.exists(pid_path): os.remove(pid_path)
    logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
267
-
268
if __name__ == "__main__":
    # Force the "spawn" start method so pool workers begin in a fresh
    # interpreter (presumably to avoid fork-related issues with native
    # libraries loaded by the parent process — TODO confirm).
    mp.set_start_method("spawn", force=True)
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/separator_music.py DELETED
@@ -1,310 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import logging
5
- import argparse
6
- import logging.handlers
7
-
8
- import numpy as np
9
-
10
- from distutils.util import strtobool
11
-
12
- sys.path.append(os.getcwd())
13
-
14
- from main.configs.config import Config
15
- from main.library.algorithm.separator import Separator
16
- from main.library.utils import pydub_convert, pydub_load
17
-
18
# Project configuration and localized message strings.
config = Config()
translations = config.translations
logger = logging.getLogger(__name__)

# BUG FIX: the original only defined the handlers/formatters in the `else`
# branch, so when the logger already had handlers it cleared them and left the
# logger handler-less — and `file_formatter`, which separator_main() references
# later, was never defined (NameError). Clear, then always (re)install.
if logger.hasHandlers(): logger.handlers.clear()
console_handler = logging.StreamHandler()
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
console_handler.setFormatter(console_formatter)
console_handler.setLevel(logging.INFO)
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "separator.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
file_handler.setFormatter(file_formatter)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)
35
-
36
# UI model name -> Demucs model definition file.
demucs_models = {"HT-Tuned": "htdemucs_ft.yaml", "HT-Normal": "htdemucs.yaml", "HD_MMI": "hdemucs_mmi.yaml", "HT_6S": "htdemucs_6s.yaml"}
# UI model name -> MDX-Net ONNX checkpoint used for vocal/instrumental splits.
mdx_models = {"Main_340": "UVR-MDX-NET_Main_340.onnx", "Main_390": "UVR-MDX-NET_Main_390.onnx", "Main_406": "UVR-MDX-NET_Main_406.onnx", "Main_427": "UVR-MDX-NET_Main_427.onnx","Main_438": "UVR-MDX-NET_Main_438.onnx", "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx", "Inst_HQ_1": "UVR-MDX-NET-Inst_HQ_1.onnx", "Inst_HQ_2": "UVR-MDX-NET-Inst_HQ_2.onnx", "Inst_HQ_3": "UVR-MDX-NET-Inst_HQ_3.onnx", "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx", "Inst_HQ_5": "UVR-MDX-NET-Inst_HQ_5.onnx", "Kim_Vocal_1": "Kim_Vocal_1.onnx", "Kim_Vocal_2": "Kim_Vocal_2.onnx", "Kim_Inst": "Kim_Inst.onnx", "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx", "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx", "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx", "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx", "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx", "MDXNET_9482": "UVR_MDXNET_9482.onnx", "Inst_1": "UVR-MDX-NET-Inst_1.onnx", "Inst_2": "UVR-MDX-NET-Inst_2.onnx", "Inst_3": "UVR-MDX-NET-Inst_3.onnx", "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx", "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx", "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx", "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx", "MDXNET_Main": "UVR_MDXNET_Main.onnx"}
# Karaoke model version -> ONNX checkpoint for the main/backing vocal split.
kara_models = {"Version-1": "UVR_MDXNET_KARA.onnx", "Version-2": "UVR_MDXNET_KARA_2.onnx"}
39
-
40
def parse_arguments():
    """Build and parse the CLI flags for the music separator.

    Only --input_path is required; everything else falls back to sensible
    defaults. Boolean flags accept the usual strtobool spellings
    (true/false, yes/no, 1/0).
    """
    parser = argparse.ArgumentParser()

    def bool_arg(value):
        # strtobool returns 0/1; normalize to a real bool.
        return bool(strtobool(value))

    parser.add_argument("--input_path", type=str, required=True)

    # (flag, type, default) for every optional argument.
    optional_specs = [
        ("--output_path", str, "./audios"),
        ("--format", str, "wav"),
        ("--shifts", int, 2),
        ("--segments_size", int, 256),
        ("--overlap", float, 0.25),
        ("--mdx_hop_length", int, 1024),
        ("--mdx_batch_size", int, 1),
        ("--clean_audio", bool_arg, False),
        ("--clean_strength", float, 0.7),
        ("--model_name", str, "HT-Normal"),
        ("--kara_model", str, "Version-1"),
        ("--backing", bool_arg, False),
        ("--mdx_denoise", bool_arg, False),
        ("--reverb", bool_arg, False),
        ("--backing_reverb", bool_arg, False),
        ("--sample_rate", int, 44100),
    ]
    for flag, value_type, default in optional_specs:
        parser.add_argument(flag, type=value_type, default=default)

    return parser.parse_args()
61
-
62
def main():
    """CLI entry point: parse flags, validate the stage combination, and run
    separation on a single file or every entry of a directory."""
    start_time = time.time()
    # PID file lets the UI find and stop a running separation job.
    pid_path = os.path.join("assets", "separate_pid.txt")

    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        args = parse_arguments()
        input_path, output_path, export_format, shifts, segments_size, overlap, hop_length, batch_size, clean_audio, clean_strength, model_name, kara_model, backing, mdx_denoise, reverb, backing_reverb, sample_rate = args.input_path, args.output_path, args.format, args.shifts, args.segments_size, args.overlap, args.mdx_hop_length, args.mdx_batch_size, args.clean_audio, args.clean_strength, args.model_name, args.kara_model, args.backing, args.mdx_denoise, args.reverb, args.backing_reverb, args.sample_rate

        # De-reverbing backing vocals requires both the reverb stage and the
        # backing-vocal split to be enabled.
        if backing_reverb and not reverb:
            logger.warning(translations["turn_on_dereverb"])
            sys.exit(1)

        if backing_reverb and not backing:
            logger.warning(translations["turn_on_separator_backing"])
            sys.exit(1)

        # Strip stray quotes/whitespace copy-pasted around the path.
        input_path = input_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        # NOTE(review): dirname("./audios") == "." so a bare directory argument
        # collapses to its parent here — confirm this is intended.
        output_path = os.path.dirname(output_path) or output_path

        log_data = {translations['audio_path']: input_path, translations['output_path']: output_path, translations['export_format']: export_format, translations['shift']: shifts, translations['segments_size']: segments_size, translations['overlap']: overlap, translations['modelname']: model_name, translations['denoise_mdx']: mdx_denoise, "Hop length": hop_length, translations['batch_size']: batch_size, translations['sr']: sample_rate}

        if clean_audio:
            log_data[translations['clear_audio']] = clean_audio
            log_data[translations['clean_strength']] = clean_strength

        if backing:
            log_data[translations['backing_model_ver']] = kara_model
            log_data[translations['separator_backing']] = backing

        if reverb:
            log_data[translations['dereveb_audio']] = reverb
            log_data[translations['dereveb_backing']] = backing_reverb

        for key, value in log_data.items():
            logger.debug(f"{key}: {value}")

        if os.path.isdir(input_path):
            # BUG FIX: the original iterated the *path string* (`for f in
            # input_path`), calling separation() once per character. Iterate
            # the directory's entries instead.
            for name in os.listdir(input_path):
                separation(os.path.join(input_path, name), output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength)
        else: separation(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength)

    except Exception as e:
        logger.error(f"{translations['separator_error']}: {e}")
        import traceback
        logger.debug(traceback.format_exc())

    if os.path.exists(pid_path): os.remove(pid_path)
    elapsed_time = time.time() - start_time
    logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
114
-
115
def separation(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength):
    """Run the full separation pipeline for one input file.

    Stage 1 splits vocals/instruments (Demucs or MDX depending on model_name);
    optional stage 2 splits main/backing vocals; optional stage 3 de-reverbs;
    optional final pass noise-cleans the vocal outputs in place.
    """
    filename, _ = os.path.splitext(os.path.basename(input_path))
    output_path = os.path.join(output_path, filename)
    os.makedirs(output_path, exist_ok=True)

    # Stage 1: vocals / instruments split.
    if model_name in ["HT-Tuned", "HT-Normal", "HD_MMI", "HT_6S"]: vocals, _ = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate)
    else: vocals, _ = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, model_name, hop_length, batch_size, sample_rate)

    # Stage 2/3 (optional): main/backing split and de-reverb.
    if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, mdx_denoise, kara_model, hop_length, batch_size, sample_rate)
    if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, mdx_denoise, reverb, backing_reverb, hop_length, batch_size, sample_rate)

    original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
    main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Main_Vocals.{export_format}")
    backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")

    if clean_audio:
        import soundfile as sf

        logger.info(f"{translations['clear_audio']}...")
        vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals, dtype=np.float32)

        from main.tools.noisereduce import reduce_noise
        # BUG FIX: `device=config.device` was passed to sf.write() — which has
        # no such parameter and raises TypeError — instead of reduce_noise();
        # matches the reduce_noise(..., device=config.device) call used in
        # preprocessing.
        sf.write(original_output, reduce_noise(y=vocal_data, sr=vocal_sr, prop_decrease=clean_strength, device=config.device), vocal_sr, format=export_format)

        if backing:
            # NOTE(review): separator_reverb() only returns a main no-reverb
            # path when backing_reverb is set, but this reads it when
            # `reverb and backing` — confirm the condition mismatch.
            main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals, dtype=np.float32)
            backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals, dtype=np.float32)

            sf.write(main_output, reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength, device=config.device), main_sr, format=export_format)
            sf.write(backing_output, reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength, device=config.device), backing_sr, format=export_format)

        logger.info(translations["clean_audio_success"])
147
-
148
def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model, sample_rate):
    """Separate a song with a Demucs model into vocals plus one merged
    instrumental track.

    Returns (vocals_path, instruments_path). The drum/bass/other stems are
    overlaid into a single "Instruments" file and the intermediates deleted.
    Exits the process if the input or output path does not exist.
    """
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        sys.exit(1)

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    # Remove stale outputs from a previous run.
    for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))

    logger.info(f"{translations['separator_process_2']}...")
    # NOTE(review): Demucs is given half the configured segment size — confirm
    # this halving is intentional.
    demucs_output = separator_main(audio_file=input, model_filename=demucs_models.get(demucs_model), output_format=format, output_dir=output, demucs_segment_size=(segments_size / 2), demucs_shifts=shifts, demucs_overlap=overlap, sample_rate=sample_rate)

    # Sort produced stems by name. NOTE(review): if the separator ever fails
    # to emit one of Drums/Bass/Other, the variables below stay unbound and
    # the overlay line raises NameError.
    for f in demucs_output:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

        if '_(Drums)_' in f: drums = path
        elif '_(Bass)_' in f: bass = path
        elif '_(Other)_' in f: other = path
        elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))

    # Mix drums + bass + other back into a single instrumental file.
    pydub_convert(pydub_load(drums)).overlay(pydub_convert(pydub_load(bass))).overlay(pydub_convert(pydub_load(other))).export(os.path.join(output, f"Instruments.{format}"), format=format)

    # Drop the per-stem intermediates once merged.
    for f in [drums, bass, other]:
        if os.path.exists(f): os.remove(f)

    logger.info(translations["separator_success_2"])
    return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
179
-
180
def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size, sample_rate):
    """Split an already-extracted vocal track into main and backing vocals
    with the selected karaoke MDX model.

    Returns (main_vocals_path, backing_vocals_path). Exits the process when
    the input or output path is missing.
    """
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        sys.exit(1)

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    # Clear stale results from a previous run.
    for stale_name in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
        stale_path = os.path.join(output, stale_name)
        if os.path.exists(stale_path): os.remove(stale_path)

    model_2 = kara_models.get(kara_model)
    logger.info(f"{translations['separator_process_backing']}...")

    produced = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
    main_output = os.path.join(output, f"Main_Vocals.{format}")
    backing_output = os.path.join(output, f"Backing_Vocals.{format}")

    # Rename the model's stems to the canonical names. The "Instrumental"
    # stem of a karaoke model is the backing vocals.
    for produced_name in produced:
        produced_path = os.path.join(output, produced_name)
        if not os.path.exists(produced_path): logger.error(translations["not_found"].format(name=produced_path))

        if '_(Instrumental)_' in produced_name: os.rename(produced_path, backing_output)
        elif '_(Vocals)_' in produced_name: os.rename(produced_path, main_output)

    logger.info(translations["separator_process_backing_success"])
    return main_output, backing_output
208
-
209
def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size, sample_rate):
    """Separate a song with an MDX-Net model into vocals and instruments.

    Returns (vocals_path, instruments_path). Exits the process when the
    input or output path does not exist.
    """
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        sys.exit(1)

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    # Remove stale outputs from a previous run.
    for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))

    model_3 = mdx_models.get(mdx_model)
    logger.info(f"{translations['separator_process_2']}...")

    output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
    original_output, instruments_output = os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")

    # Rename the model's stems to the canonical names.
    for f in output_music:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

        if '_(Instrumental)_' in f: os.rename(path, instruments_output)
        elif '_(Vocals)_' in f: os.rename(path, original_output)

    # FIX: the original logged "separator_process_backing_success" here — a
    # copy-paste from separator_backing(); this is the music split, so log the
    # same completion message the Demucs path uses.
    logger.info(translations["separator_success_2"])
    return original_output, instruments_output
236
-
237
def separator_reverb(output, format, segments_size, overlap, denoise, original, backing_reverb, hop_length, batch_size, sample_rate):
    """De-reverb the previously separated vocal tracks in `output` with the
    Reverb_HQ MDX model.

    original: also process the Original_Vocals track.
    backing_reverb: also process the Main_Vocals and Backing_Vocals tracks.
    Returns a 3-tuple of the no-reverb paths (or None for skipped tracks).
    Exits the process when a required input file cannot be found.
    """
    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    # Remove stale outputs from a previous run.
    for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))

    dereveb_path = []

    if original:
        try:
            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
        except IndexError:
            logger.warning(translations["not_found_original_vocal"])
            sys.exit(1)

    # NOTE(review): both of the following blocks are gated on backing_reverb;
    # per the warning keys they add the Main and Backing vocal tracks — confirm
    # the first was not meant to be gated on a separate `backing` flag.
    if backing_reverb:
        try:
            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
        except IndexError:
            logger.warning(translations["not_found_main_vocal"])
            sys.exit(1)

    if backing_reverb:
        try:
            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
        except IndexError:
            logger.warning(translations["not_found_backing_vocal"])
            sys.exit(1)

    for path in dereveb_path:
        if not os.path.exists(path):
            logger.warning(translations["not_found"].format(name=path))
            sys.exit(1)

        # Pick output names and log messages per track type. NOTE(review):
        # if a path matches none of the three patterns, reverb_path/start_title
        # would be unbound (or stale from the previous iteration).
        if "Original_Vocals" in path:
            reverb_path, no_reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}"), os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
            start_title, end_title = translations["process_original"], translations["process_original_success"]
        elif "Main_Vocals" in path:
            reverb_path, no_reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}"), os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
            start_title, end_title = translations["process_main"], translations["process_main_success"]
        elif "Backing_Vocals" in path:
            reverb_path, no_reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}"), os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
            start_title, end_title = translations["process_backing"], translations["process_backing_success"]

        logger.info(start_title)
        output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)

        # Rename the model's (Reverb)/(No Reverb) stems to canonical names.
        # Note `path` is deliberately reused here; the outer loop iterates the
        # already-built dereveb_path list, so this is safe.
        for f in output_dereveb:
            path = os.path.join(output, f)
            if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

            if '_(Reverb)_' in f: os.rename(path, reverb_path)
            elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)

        logger.info(end_title)

    return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if backing_reverb else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)
296
-
297
def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25, sample_rate=44100):
    """Run the UVR Separator on one file and return the list of produced
    stem file names.

    First tries the caller's parameters; if building/running the separator
    fails, retries once with known-good default parameters.
    """
    try:
        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": demucs_segment_size, "shifts": demucs_shifts, "overlap": demucs_overlap, "segments_enabled": True})
        separator.load_model(model_filename=model_filename)

        return separator.separate(audio_file)
    except Exception:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit; narrow to Exception and keep the traceback in the log
        # before retrying with safe defaults.
        logger.debug(translations["default_setting"], exc_info=True)
        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": 128, "shifts": 2, "overlap": 0.25, "segments_enabled": True})
        separator.load_model(model_filename=model_filename)

        return separator.separate(audio_file)
309
-
310
# Script entry point.
if __name__ == "__main__": main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/inference/train.py DELETED
@@ -1,990 +0,0 @@
1
- import os
2
- import sys
3
- import glob
4
- import json
5
- import torch
6
- import hashlib
7
- import logging
8
- import argparse
9
- import datetime
10
- import warnings
11
- import logging.handlers
12
-
13
- import numpy as np
14
- import soundfile as sf
15
- import matplotlib.pyplot as plt
16
- import torch.distributed as dist
17
- import torch.utils.data as tdata
18
- import torch.multiprocessing as mp
19
-
20
- from tqdm import tqdm
21
- from collections import OrderedDict
22
- from random import randint, shuffle
23
- from torch.utils.checkpoint import checkpoint
24
- from torch.cuda.amp import GradScaler, autocast
25
- from torch.utils.tensorboard import SummaryWriter
26
-
27
- from time import time as ttime
28
- from torch.nn import functional as F
29
- from distutils.util import strtobool
30
- from librosa.filters import mel as librosa_mel_fn
31
- from torch.nn.parallel import DistributedDataParallel as DDP
32
- from torch.nn.utils.parametrizations import spectral_norm, weight_norm
33
-
34
- sys.path.append(os.getcwd())
35
-
36
- from main.configs.config import Config
37
- from main.library.algorithm.residuals import LRELU_SLOPE
38
- from main.library.algorithm.synthesizers import Synthesizer
39
- from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
40
-
41
- MATPLOTLIB_FLAG = False
42
- main_config = Config()
43
- translations = main_config.translations
44
- warnings.filterwarnings("ignore")
45
- logging.getLogger("torch").setLevel(logging.ERROR)
46
-
47
- class HParams:
48
- def __init__(self, **kwargs):
49
- for k, v in kwargs.items():
50
- self[k] = HParams(**v) if isinstance(v, dict) else v
51
-
52
- def keys(self):
53
- return self.__dict__.keys()
54
-
55
- def items(self):
56
- return self.__dict__.items()
57
-
58
- def values(self):
59
- return self.__dict__.values()
60
-
61
- def __len__(self):
62
- return len(self.__dict__)
63
-
64
- def __getitem__(self, key):
65
- return self.__dict__[key]
66
-
67
- def __setitem__(self, key, value):
68
- self.__dict__[key] = value
69
-
70
- def __contains__(self, key):
71
- return key in self.__dict__
72
-
73
- def __repr__(self):
74
- return repr(self.__dict__)
75
-
76
- def parse_arguments():
77
- parser = argparse.ArgumentParser()
78
- parser.add_argument("--model_name", type=str, required=True)
79
- parser.add_argument("--rvc_version", type=str, default="v2")
80
- parser.add_argument("--save_every_epoch", type=int, required=True)
81
- parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
82
- parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
83
- parser.add_argument("--total_epoch", type=int, default=300)
84
- parser.add_argument("--sample_rate", type=int, required=True)
85
- parser.add_argument("--batch_size", type=int, default=8)
86
- parser.add_argument("--gpu", type=str, default="0")
87
- parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
88
- parser.add_argument("--g_pretrained_path", type=str, default="")
89
- parser.add_argument("--d_pretrained_path", type=str, default="")
90
- parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
91
- parser.add_argument("--overtraining_threshold", type=int, default=50)
92
- parser.add_argument("--cleanup", type=lambda x: bool(strtobool(x)), default=False)
93
- parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
94
- parser.add_argument("--model_author", type=str)
95
- parser.add_argument("--vocoder", type=str, default="Default")
96
- parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
97
- parser.add_argument("--deterministic", type=lambda x: bool(strtobool(x)), default=False)
98
- parser.add_argument("--benchmark", type=lambda x: bool(strtobool(x)), default=False)
99
-
100
- return parser.parse_args()
101
-
102
- args = parse_arguments()
103
- model_name, save_every_epoch, total_epoch, pretrainG, pretrainD, version, gpus, batch_size, sample_rate, pitch_guidance, save_only_latest, save_every_weights, cache_data_in_gpu, overtraining_detector, overtraining_threshold, cleanup, model_author, vocoder, checkpointing = args.model_name, args.save_every_epoch, args.total_epoch, args.g_pretrained_path, args.d_pretrained_path, args.rvc_version, args.gpu, args.batch_size, args.sample_rate, args.pitch_guidance, args.save_only_latest, args.save_every_weights, args.cache_data_in_gpu, args.overtraining_detector, args.overtraining_threshold, args.cleanup, args.model_author, args.vocoder, args.checkpointing
104
-
105
- experiment_dir = os.path.join("assets", "logs", model_name)
106
- training_file_path = os.path.join(experiment_dir, "training_data.json")
107
- config_save_path = os.path.join(experiment_dir, "config.json")
108
- torch.backends.cudnn.deterministic = args.deterministic
109
- torch.backends.cudnn.benchmark = args.benchmark
110
- lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
111
- global_step, last_loss_gen_all, overtrain_save_epoch = 0, 0, 0
112
- loss_gen_history, smoothed_loss_gen_history, loss_disc_history, smoothed_loss_disc_history = [], [], [], []
113
-
114
- with open(config_save_path, "r") as f:
115
- config = json.load(f)
116
-
117
- config = HParams(**config)
118
- config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
119
- logger = logging.getLogger(__name__)
120
-
121
- if logger.hasHandlers(): logger.handlers.clear()
122
- else:
123
- console_handler = logging.StreamHandler()
124
- console_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
125
- console_handler.setLevel(logging.INFO)
126
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_dir, "train.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
127
- file_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
128
- file_handler.setLevel(logging.DEBUG)
129
- logger.addHandler(console_handler)
130
- logger.addHandler(file_handler)
131
- logger.setLevel(logging.DEBUG)
132
-
133
- log_data = {translations['modelname']: model_name, translations["save_every_epoch"]: save_every_epoch, translations["total_e"]: total_epoch, translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "", translations['training_version']: version, "Gpu": gpus, translations['batch_size']: batch_size, translations['pretrain_sr']: sample_rate, translations['training_f0']: pitch_guidance, translations['save_only_latest']: save_only_latest, translations['save_every_weights']: save_every_weights, translations['cache_in_gpu']: cache_data_in_gpu, translations['overtraining_detector']: overtraining_detector, translations['threshold']: overtraining_threshold, translations['cleanup_training']: cleanup, translations['memory_efficient_training']: checkpointing}
134
- if model_author: log_data[translations["model_author"].format(model_author=model_author)] = ""
135
- if vocoder != "Default": log_data[translations['vocoder']] = vocoder
136
-
137
- for key, value in log_data.items():
138
- logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
139
-
140
- def main():
141
- global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing, gpus
142
-
143
- try:
144
- os.environ["MASTER_ADDR"] = "localhost"
145
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
146
-
147
- if torch.cuda.is_available():
148
- device, gpus = torch.device("cuda"), [int(item) for item in gpus.split("-")]
149
- n_gpus = len(gpus)
150
- elif torch.backends.mps.is_available():
151
- device, gpus = torch.device("mps"), [0]
152
- n_gpus = 1
153
- else:
154
- device, gpus = torch.device("cpu"), [0]
155
- n_gpus = 1
156
- logger.warning(translations["not_gpu"])
157
-
158
- def start():
159
- children = []
160
- pid_data = {"process_pids": []}
161
-
162
- with open(config_save_path, "r") as pid_file:
163
- try:
164
- pid_data.update(json.load(pid_file))
165
- except json.JSONDecodeError:
166
- pass
167
-
168
- with open(config_save_path, "w") as pid_file:
169
- for rank, device_id in enumerate(gpus):
170
- subproc = mp.Process(target=run, args=(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, total_epoch, save_every_weights, config, device, device_id, model_author, vocoder, checkpointing))
171
- children.append(subproc)
172
- subproc.start()
173
- pid_data["process_pids"].append(subproc.pid)
174
-
175
- json.dump(pid_data, pid_file, indent=4)
176
-
177
- for i in range(n_gpus):
178
- children[i].join()
179
-
180
- def load_from_json(file_path):
181
- if os.path.exists(file_path):
182
- with open(file_path, "r") as f:
183
- data = json.load(f)
184
- return (data.get("loss_disc_history", []), data.get("smoothed_loss_disc_history", []), data.get("loss_gen_history", []), data.get("smoothed_loss_gen_history", []))
185
- return [], [], [], []
186
-
187
- def continue_overtrain_detector(training_file_path):
188
- if overtraining_detector and os.path.exists(training_file_path): (loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) = load_from_json(training_file_path)
189
-
190
- if cleanup:
191
- for root, dirs, files in os.walk(experiment_dir, topdown=False):
192
- for name in files:
193
- file_path = os.path.join(root, name)
194
- _, file_extension = os.path.splitext(name)
195
- if (file_extension == ".0" or (name.startswith("D_") and file_extension == ".pth") or (name.startswith("G_") and file_extension == ".pth") or (file_extension == ".index")): os.remove(file_path)
196
-
197
- for name in dirs:
198
- if name == "eval":
199
- folder_path = os.path.join(root, name)
200
- for item in os.listdir(folder_path):
201
- item_path = os.path.join(folder_path, item)
202
- if os.path.isfile(item_path): os.remove(item_path)
203
- os.rmdir(folder_path)
204
-
205
- continue_overtrain_detector(training_file_path)
206
- start()
207
- except Exception as e:
208
- logger.error(f"{translations['training_error']} {e}")
209
- import traceback
210
- logger.debug(traceback.format_exc())
211
-
212
- def plot_spectrogram_to_numpy(spectrogram):
213
- global MATPLOTLIB_FLAG
214
-
215
- if not MATPLOTLIB_FLAG:
216
- plt.switch_backend("Agg")
217
- MATPLOTLIB_FLAG = True
218
-
219
- fig, ax = plt.subplots(figsize=(10, 2))
220
- plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
221
- plt.xlabel("Frames")
222
- plt.ylabel("Channels")
223
- plt.tight_layout()
224
- fig.canvas.draw()
225
- plt.close(fig)
226
-
227
- try:
228
- data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
229
- except:
230
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="").reshape(fig.canvas.get_width_height()[::-1] + (3,))
231
-
232
- return data
233
-
234
- def verify_checkpoint_shapes(checkpoint_path, model):
235
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
236
- checkpoint_state_dict = checkpoint["model"]
237
- try:
238
- model_state_dict = model.module.load_state_dict(checkpoint_state_dict) if hasattr(model, "module") else model.load_state_dict(checkpoint_state_dict)
239
- except RuntimeError:
240
- logger.warning(translations["checkpointing_err"])
241
- sys.exit(1)
242
- else: del checkpoint, checkpoint_state_dict, model_state_dict
243
-
244
- def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
245
- for k, v in scalars.items():
246
- writer.add_scalar(k, v, global_step)
247
-
248
- for k, v in histograms.items():
249
- writer.add_histogram(k, v, global_step)
250
-
251
- for k, v in images.items():
252
- writer.add_image(k, v, global_step, dataformats="HWC")
253
-
254
- for k, v in audios.items():
255
- writer.add_audio(k, v, global_step, audio_sample_rate)
256
-
257
- def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
258
- assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
259
- checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(torch.load(checkpoint_path, map_location="cpu"), ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
260
- new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in (model.module.state_dict() if hasattr(model, "module") else model.state_dict()).items()}
261
-
262
- if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
263
- else: model.load_state_dict(new_state_dict, strict=False)
264
-
265
- if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
266
- logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
267
- return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])
268
-
269
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
270
- state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
271
- torch.save(replace_keys_in_dict(replace_keys_in_dict({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), checkpoint_path)
272
- logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
273
-
274
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
275
- checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
276
- return checkpoints[-1] if checkpoints else None
277
-
278
- def load_wav_to_torch(full_path):
279
- data, sample_rate = sf.read(full_path, dtype=np.float32)
280
- return torch.FloatTensor(data.astype(np.float32)), sample_rate
281
-
282
- def load_filepaths_and_text(filename, split="|"):
283
- with open(filename, encoding="utf-8") as f:
284
- return [line.strip().split(split) for line in f]
285
-
286
- def feature_loss(fmap_r, fmap_g):
287
- loss = 0
288
- for dr, dg in zip(fmap_r, fmap_g):
289
- for rl, gl in zip(dr, dg):
290
- loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
291
- return loss * 2
292
-
293
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
294
- loss = 0
295
- r_losses, g_losses = [], []
296
-
297
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
298
- dr = dr.float()
299
- dg = dg.float()
300
- r_loss = torch.mean((1 - dr) ** 2)
301
- g_loss = torch.mean(dg**2)
302
- loss += r_loss + g_loss
303
- r_losses.append(r_loss.item())
304
- g_losses.append(g_loss.item())
305
- return loss, r_losses, g_losses
306
-
307
- def generator_loss(disc_outputs):
308
- loss = 0
309
- gen_losses = []
310
-
311
- for dg in disc_outputs:
312
- l = torch.mean((1 - dg.float()) ** 2)
313
- gen_losses.append(l)
314
- loss += l
315
- return loss, gen_losses
316
-
317
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
318
- z_p = z_p.float()
319
- logs_q = logs_q.float()
320
- m_p = m_p.float()
321
- logs_p = logs_p.float()
322
- z_mask = z_mask.float()
323
- kl = logs_p - logs_q - 0.5
324
- kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
325
- return torch.sum(kl * z_mask) / torch.sum(z_mask)
326
-
327
- class TextAudioLoaderMultiNSFsid(tdata.Dataset):
328
- def __init__(self, hparams):
329
- self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
330
- self.max_wav_value = hparams.max_wav_value
331
- self.sample_rate = hparams.sample_rate
332
- self.filter_length = hparams.filter_length
333
- self.hop_length = hparams.hop_length
334
- self.win_length = hparams.win_length
335
- self.sample_rate = hparams.sample_rate
336
- self.min_text_len = getattr(hparams, "min_text_len", 1)
337
- self.max_text_len = getattr(hparams, "max_text_len", 5000)
338
- self._filter()
339
-
340
- def _filter(self):
341
- audiopaths_and_text_new, lengths = [], []
342
- for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
343
- if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
344
- audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
345
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
346
-
347
- self.audiopaths_and_text = audiopaths_and_text_new
348
- self.lengths = lengths
349
-
350
- def get_sid(self, sid):
351
- try:
352
- sid = torch.LongTensor([int(sid)])
353
- except ValueError as e:
354
- logger.error(translations["sid_error"].format(sid=sid, e=e))
355
- sid = torch.LongTensor([0])
356
- return sid
357
-
358
- def get_audio_text_pair(self, audiopath_and_text):
359
- phone, pitch, pitchf = self.get_labels(audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3])
360
- spec, wav = self.get_audio(audiopath_and_text[0])
361
- dv = self.get_sid(audiopath_and_text[4])
362
- len_phone = phone.size()[0]
363
- len_spec = spec.size()[-1]
364
-
365
- if len_phone != len_spec:
366
- len_min = min(len_phone, len_spec)
367
- len_wav = len_min * self.hop_length
368
- spec, wav, phone = spec[:, :len_min], wav[:, :len_wav], phone[:len_min, :]
369
- pitch, pitchf = pitch[:len_min], pitchf[:len_min]
370
- return (spec, wav, phone, pitch, pitchf, dv)
371
-
372
- def get_labels(self, phone, pitch, pitchf):
373
- phone = np.repeat(np.load(phone), 2, axis=0)
374
- n_num = min(phone.shape[0], 900)
375
- return torch.FloatTensor(phone[:n_num, :]), torch.LongTensor(np.load(pitch)[:n_num]), torch.FloatTensor(np.load(pitchf)[:n_num])
376
-
377
- def get_audio(self, filename):
378
- audio, sample_rate = load_wav_to_torch(filename)
379
- if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
380
- audio_norm = audio.unsqueeze(0)
381
- spec_filename = filename.replace(".wav", ".spec.pt")
382
-
383
- if os.path.exists(spec_filename):
384
- try:
385
- spec = torch.load(spec_filename)
386
- except Exception as e:
387
- logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
388
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
389
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
390
- else:
391
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
392
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
393
- return spec, audio_norm
394
-
395
- def __getitem__(self, index):
396
- return self.get_audio_text_pair(self.audiopaths_and_text[index])
397
-
398
- def __len__(self):
399
- return len(self.audiopaths_and_text)
400
-
401
- class TextAudioCollateMultiNSFsid:
402
- def __init__(self, return_ids=False):
403
- self.return_ids = return_ids
404
-
405
- def __call__(self, batch):
406
- _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
407
- spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
408
- spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
409
- spec_padded.zero_()
410
- wave_padded.zero_()
411
- max_phone_len = max([x[2].size(0) for x in batch])
412
- phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
413
- pitch_padded, pitchf_padded = torch.LongTensor(len(batch), max_phone_len), torch.FloatTensor(len(batch), max_phone_len)
414
- phone_padded.zero_()
415
- pitch_padded.zero_()
416
- pitchf_padded.zero_()
417
- sid = torch.LongTensor(len(batch))
418
-
419
- for i in range(len(ids_sorted_decreasing)):
420
- row = batch[ids_sorted_decreasing[i]]
421
- spec = row[0]
422
- spec_padded[i, :, : spec.size(1)] = spec
423
- spec_lengths[i] = spec.size(1)
424
- wave = row[1]
425
- wave_padded[i, :, : wave.size(1)] = wave
426
- wave_lengths[i] = wave.size(1)
427
- phone = row[2]
428
- phone_padded[i, : phone.size(0), :] = phone
429
- phone_lengths[i] = phone.size(0)
430
- pitch = row[3]
431
- pitch_padded[i, : pitch.size(0)] = pitch
432
- pitchf = row[4]
433
- pitchf_padded[i, : pitchf.size(0)] = pitchf
434
- sid[i] = row[5]
435
- return (phone_padded, phone_lengths, pitch_padded, pitchf_padded, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
436
-
437
- class TextAudioLoader(tdata.Dataset):
438
- def __init__(self, hparams):
439
- self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
440
- self.max_wav_value = hparams.max_wav_value
441
- self.sample_rate = hparams.sample_rate
442
- self.filter_length = hparams.filter_length
443
- self.hop_length = hparams.hop_length
444
- self.win_length = hparams.win_length
445
- self.sample_rate = hparams.sample_rate
446
- self.min_text_len = getattr(hparams, "min_text_len", 1)
447
- self.max_text_len = getattr(hparams, "max_text_len", 5000)
448
- self._filter()
449
-
450
- def _filter(self):
451
- audiopaths_and_text_new, lengths = [], []
452
- for entry in self.audiopaths_and_text:
453
- if len(entry) >= 3:
454
- audiopath, text, dv = entry[:3]
455
- if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
456
- audiopaths_and_text_new.append([audiopath, text, dv])
457
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
458
-
459
- self.audiopaths_and_text = audiopaths_and_text_new
460
- self.lengths = lengths
461
-
462
- def get_sid(self, sid):
463
- try:
464
- sid = torch.LongTensor([int(sid)])
465
- except ValueError as e:
466
- logger.error(translations["sid_error"].format(sid=sid, e=e))
467
- sid = torch.LongTensor([0])
468
- return sid
469
-
470
- def get_audio_text_pair(self, audiopath_and_text):
471
- phone = self.get_labels(audiopath_and_text[1])
472
- spec, wav = self.get_audio(audiopath_and_text[0])
473
- dv = self.get_sid(audiopath_and_text[2])
474
- len_phone = phone.size()[0]
475
- len_spec = spec.size()[-1]
476
-
477
- if len_phone != len_spec:
478
- len_min = min(len_phone, len_spec)
479
- len_wav = len_min * self.hop_length
480
- spec = spec[:, :len_min]
481
- wav = wav[:, :len_wav]
482
- phone = phone[:len_min, :]
483
- return (spec, wav, phone, dv)
484
-
485
- def get_labels(self, phone):
486
- phone = np.repeat(np.load(phone), 2, axis=0)
487
- return torch.FloatTensor(phone[:min(phone.shape[0], 900), :])
488
-
489
- def get_audio(self, filename):
490
- audio, sample_rate = load_wav_to_torch(filename)
491
- if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
492
- audio_norm = audio.unsqueeze(0)
493
- spec_filename = filename.replace(".wav", ".spec.pt")
494
-
495
- if os.path.exists(spec_filename):
496
- try:
497
- spec = torch.load(spec_filename)
498
- except Exception as e:
499
- logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
500
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
501
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
502
- else:
503
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
504
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
505
- return spec, audio_norm
506
-
507
- def __getitem__(self, index):
508
- return self.get_audio_text_pair(self.audiopaths_and_text[index])
509
-
510
- def __len__(self):
511
- return len(self.audiopaths_and_text)
512
-
513
- class TextAudioCollate:
514
- def __init__(self, return_ids=False):
515
- self.return_ids = return_ids
516
-
517
- def __call__(self, batch):
518
- _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
519
- spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
520
- spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
521
- spec_padded.zero_()
522
- wave_padded.zero_()
523
- max_phone_len = max([x[2].size(0) for x in batch])
524
- phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
525
- phone_padded.zero_()
526
- sid = torch.LongTensor(len(batch))
527
- for i in range(len(ids_sorted_decreasing)):
528
- row = batch[ids_sorted_decreasing[i]]
529
- spec = row[0]
530
- spec_padded[i, :, : spec.size(1)] = spec
531
- spec_lengths[i] = spec.size(1)
532
- wave = row[1]
533
- wave_padded[i, :, : wave.size(1)] = wave
534
- wave_lengths[i] = wave.size(1)
535
- phone = row[2]
536
- phone_padded[i, : phone.size(0), :] = phone
537
- phone_lengths[i] = phone.size(0)
538
- sid[i] = row[3]
539
- return (phone_padded, phone_lengths, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
540
-
541
- class DistributedBucketSampler(tdata.distributed.DistributedSampler):
542
- def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
543
- super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
544
- self.lengths = dataset.lengths
545
- self.batch_size = batch_size
546
- self.boundaries = boundaries
547
- self.buckets, self.num_samples_per_bucket = self._create_buckets()
548
- self.total_size = sum(self.num_samples_per_bucket)
549
- self.num_samples = self.total_size // self.num_replicas
550
-
551
- def _create_buckets(self):
552
- buckets = [[] for _ in range(len(self.boundaries) - 1)]
553
- for i in range(len(self.lengths)):
554
- idx_bucket = self._bisect(self.lengths[i])
555
- if idx_bucket != -1: buckets[idx_bucket].append(i)
556
-
557
- for i in range(len(buckets) - 1, -1, -1):
558
- if len(buckets[i]) == 0:
559
- buckets.pop(i)
560
- self.boundaries.pop(i + 1)
561
-
562
- num_samples_per_bucket = []
563
- for i in range(len(buckets)):
564
- len_bucket = len(buckets[i])
565
- total_batch_size = self.num_replicas * self.batch_size
566
- num_samples_per_bucket.append(len_bucket + ((total_batch_size - (len_bucket % total_batch_size)) % total_batch_size))
567
- return buckets, num_samples_per_bucket
568
-
569
- def __iter__(self):
570
- g = torch.Generator()
571
- g.manual_seed(self.epoch)
572
- indices, batches = [], []
573
- if self.shuffle:
574
- for bucket in self.buckets:
575
- indices.append(torch.randperm(len(bucket), generator=g).tolist())
576
- else:
577
- for bucket in self.buckets:
578
- indices.append(list(range(len(bucket))))
579
-
580
- for i in range(len(self.buckets)):
581
- bucket = self.buckets[i]
582
- len_bucket = len(bucket)
583
- ids_bucket = indices[i]
584
- rem = self.num_samples_per_bucket[i] - len_bucket
585
- ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])[self.rank :: self.num_replicas]
586
-
587
- for j in range(len(ids_bucket) // self.batch_size):
588
- batches.append([bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]])
589
-
590
- if self.shuffle: batches = [batches[i] for i in torch.randperm(len(batches), generator=g).tolist()]
591
- self.batches = batches
592
- assert len(self.batches) * self.batch_size == self.num_samples
593
- return iter(self.batches)
594
-
595
- def _bisect(self, x, lo=0, hi=None):
596
- if hi is None: hi = len(self.boundaries) - 1
597
-
598
- if hi > lo:
599
- mid = (hi + lo) // 2
600
- if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
601
- elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
602
- else: return self._bisect(x, mid + 1, hi)
603
- else: return -1
604
-
605
- def __len__(self):
606
- return self.num_samples // self.batch_size
607
-
608
- class MultiPeriodDiscriminator(torch.nn.Module):
609
- def __init__(self, version, use_spectral_norm=False, checkpointing=False):
610
- super(MultiPeriodDiscriminator, self).__init__()
611
- self.checkpointing = checkpointing
612
- periods = ([2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37])
613
- self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods])
614
-
615
- def forward(self, y, y_hat):
616
- y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
617
- for d in self.discriminators:
618
- if self.training and self.checkpointing:
619
- def forward_discriminator(d, y, y_hat):
620
- y_d_r, fmap_r = d(y)
621
- y_d_g, fmap_g = d(y_hat)
622
- return y_d_r, fmap_r, y_d_g, fmap_g
623
- y_d_r, fmap_r, y_d_g, fmap_g = checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)
624
- else:
625
- y_d_r, fmap_r = d(y)
626
- y_d_g, fmap_g = d(y_hat)
627
-
628
- y_d_rs.append(y_d_r); fmap_rs.append(fmap_r)
629
- y_d_gs.append(y_d_g); fmap_gs.append(fmap_g)
630
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
631
-
632
- class DiscriminatorS(torch.nn.Module):
633
- def __init__(self, use_spectral_norm=False, checkpointing=False):
634
- super(DiscriminatorS, self).__init__()
635
- self.checkpointing = checkpointing
636
- norm_f = spectral_norm if use_spectral_norm else weight_norm
637
- self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
638
- self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
639
- self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
640
-
641
- def forward(self, x):
642
- fmap = []
643
- for conv in self.convs:
644
- x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
645
- fmap.append(x)
646
-
647
- x = self.conv_post(x)
648
- fmap.append(x)
649
- return torch.flatten(x, 1, -1), fmap
650
-
651
- class DiscriminatorP(torch.nn.Module):
652
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, checkpointing=False):
653
- super(DiscriminatorP, self).__init__()
654
- self.period = period
655
- self.checkpointing = checkpointing
656
- norm_f = spectral_norm if use_spectral_norm else weight_norm
657
- self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv2d(in_ch, out_ch, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))) for in_ch, out_ch in zip([1, 32, 128, 512, 1024], [32, 128, 512, 1024, 1024])])
658
- self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
659
- self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
660
-
661
- def forward(self, x):
662
- fmap = []
663
- b, c, t = x.shape
664
-
665
- if t % self.period != 0: x = F.pad(x, (0, (self.period - (t % self.period))), "reflect")
666
- x = x.view(b, c, -1, self.period)
667
-
668
- for conv in self.convs:
669
- x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
670
- fmap.append(x)
671
-
672
- x = self.conv_post(x)
673
- fmap.append(x)
674
- return torch.flatten(x, 1, -1), fmap
675
-
676
- class EpochRecorder:
677
- def __init__(self):
678
- self.last_time = ttime()
679
-
680
- def record(self):
681
- now_time = ttime()
682
- elapsed_time = now_time - self.last_time
683
- self.last_time = now_time
684
- return translations["time_or_speed_training"].format(current_time=datetime.datetime.now().strftime("%H:%M:%S"), elapsed_time_str=str(datetime.timedelta(seconds=int(round(elapsed_time, 1)))))
685
-
686
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
687
- return torch.log(torch.clamp(x, min=clip_val) * C)
688
-
689
- def dynamic_range_decompression_torch(x, C=1):
690
- return torch.exp(x) / C
691
-
692
- def spectral_normalize_torch(magnitudes):
693
- return dynamic_range_compression_torch(magnitudes)
694
-
695
- def spectral_de_normalize_torch(magnitudes):
696
- return dynamic_range_decompression_torch(magnitudes)
697
-
698
- mel_basis, hann_window = {}, {}
699
-
700
- def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
701
- global hann_window
702
-
703
- wnsize_dtype_device = str(win_size) + "_" + str(y.dtype) + "_" + str(y.device)
704
- if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
705
- spec = torch.stft(F.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect").squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
706
- return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
707
-
708
- def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
709
- global mel_basis
710
-
711
- fmax_dtype_device = str(fmax) + "_" + str(spec.dtype) + "_" + str(spec.device)
712
- if fmax_dtype_device not in mel_basis: mel_basis[fmax_dtype_device] = torch.from_numpy(librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)).to(dtype=spec.dtype, device=spec.device)
713
- return spectral_normalize_torch(torch.matmul(mel_basis[fmax_dtype_device], spec))
714
-
715
- def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
716
- return spec_to_mel_torch(spectrogram_torch(y, n_fft, hop_size, win_size, center), n_fft, num_mels, sample_rate, fmin, fmax)
717
-
718
- def replace_keys_in_dict(d, old_key_part, new_key_part):
719
- updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
720
- for key, value in d.items():
721
- updated_dict[(key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)
722
- return updated_dict
723
-
724
- def extract_model(ckpt, sr, pitch_guidance, name, model_path, epoch, step, version, hps, model_author, vocoder):
725
- try:
726
- logger.info(translations["savemodel"].format(model_dir=model_path, epoch=epoch, step=step))
727
- os.makedirs(os.path.dirname(model_path), exist_ok=True)
728
-
729
- opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
730
- opt["config"] = [hps.data.filter_length // 2 + 1, 32, hps.model.inter_channels, hps.model.hidden_channels, hps.model.filter_channels, hps.model.n_heads, hps.model.n_layers, hps.model.kernel_size, hps.model.p_dropout, hps.model.resblock, hps.model.resblock_kernel_sizes, hps.model.resblock_dilation_sizes, hps.model.upsample_rates, hps.model.upsample_initial_channel, hps.model.upsample_kernel_sizes, hps.model.spk_embed_dim, hps.model.gin_channels, hps.data.sample_rate]
731
- opt["epoch"] = f"{epoch}epoch"
732
- opt["step"] = step
733
- opt["sr"] = sr
734
- opt["f0"] = int(pitch_guidance)
735
- opt["version"] = version
736
- opt["creation_date"] = datetime.datetime.now().isoformat()
737
- opt["model_hash"] = hashlib.sha256(f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}".encode()).hexdigest()
738
- opt["model_name"] = name
739
- opt["author"] = model_author
740
- opt["vocoder"] = vocoder
741
-
742
- torch.save(replace_keys_in_dict(replace_keys_in_dict(opt, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), model_path)
743
- except Exception as e:
744
- logger.error(f"{translations['extract_model_error']}: {e}")
745
-
746
- def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, device_id, model_author, vocoder, checkpointing):
747
- global global_step
748
-
749
- if rank == 0: writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
750
- else: writer_eval = None
751
-
752
- try:
753
- dist.init_process_group(backend=("gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl"), init_method="env://", world_size=n_gpus, rank=rank)
754
- except:
755
- dist.init_process_group(backend=("gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl"), init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank)
756
-
757
- torch.manual_seed(config.train.seed)
758
- if torch.cuda.is_available(): torch.cuda.set_device(device_id)
759
-
760
- train_dataset = TextAudioLoaderMultiNSFsid(config.data) if pitch_guidance else TextAudioLoader(config.data)
761
- train_loader = tdata.DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=TextAudioCollateMultiNSFsid() if pitch_guidance else TextAudioCollate(), batch_sampler=DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True), persistent_workers=True, prefetch_factor=8)
762
-
763
- net_g, net_d = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance, sr=sample_rate, vocoder=vocoder, checkpointing=checkpointing), MultiPeriodDiscriminator(version, config.model.use_spectral_norm, checkpointing=checkpointing)
764
- net_g, net_d = (net_g.cuda(device_id), net_d.cuda(device_id)) if torch.cuda.is_available() else (net_g.to(device), net_d.to(device))
765
-
766
- optim_g, optim_d = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps), torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
767
- net_g, net_d = (DDP(net_g, device_ids=[device_id]), DDP(net_d, device_ids=[device_id])) if torch.cuda.is_available() else (DDP(net_g), DDP(net_d))
768
-
769
- try:
770
- logger.info(translations["start_training"])
771
- _, _, _, epoch_str = load_checkpoint((os.path.join(experiment_dir, "D_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "D_*.pth")), net_d, optim_d)
772
- _, _, _, epoch_str = load_checkpoint((os.path.join(experiment_dir, "G_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "G_*.pth")), net_g, optim_g)
773
- epoch_str += 1
774
- global_step = (epoch_str - 1) * len(train_loader)
775
- except:
776
- epoch_str, global_step = 1, 0
777
-
778
- if pretrainG != "" and pretrainG != "None":
779
- if rank == 0:
780
- verify_checkpoint_shapes(pretrainG, net_g)
781
- logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
782
-
783
- if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
784
- else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
785
- else: logger.warning(translations["not_using_pretrain"].format(dg="G"))
786
-
787
- if pretrainD != "" and pretrainD != "None":
788
- if rank == 0:
789
- verify_checkpoint_shapes(pretrainD, net_d)
790
- logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
791
-
792
- if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
793
- else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
794
- else: logger.warning(translations["not_using_pretrain"].format(dg="D"))
795
-
796
- scheduler_g, scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2), torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
797
- optim_d.step(); optim_g.step()
798
-
799
- scaler = GradScaler(enabled=main_config.is_half and device.type == "cuda")
800
- cache = []
801
-
802
- for info in train_loader:
803
- phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
804
- reference = (phone.cuda(device_id, non_blocking=True), phone_lengths.cuda(device_id, non_blocking=True), (pitch.cuda(device_id, non_blocking=True) if pitch_guidance else None), (pitchf.cuda(device_id, non_blocking=True) if pitch_guidance else None), sid.cuda(device_id, non_blocking=True)) if device.type == "cuda" else (phone.to(device), phone_lengths.to(device), (pitch.to(device) if pitch_guidance else None), (pitchf.to(device) if pitch_guidance else None), sid.to(device))
805
- break
806
-
807
- for epoch in range(epoch_str, total_epoch + 1):
808
- train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, train_loader, writer_eval, cache, custom_save_every_weights, custom_total_epoch, device, device_id, reference, model_author, vocoder)
809
- scheduler_g.step(); scheduler_d.step()
810
-
811
- def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, train_loader, writer, cache, custom_save_every_weights, custom_total_epoch, device, device_id, reference, model_author, vocoder):
812
- global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
813
-
814
- if epoch == 1:
815
- lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
816
- last_loss_gen_all, consecutive_increases_gen, consecutive_increases_disc = 0.0, 0, 0
817
-
818
- net_g, net_d = nets
819
- optim_g, optim_d = optims
820
- train_loader.batch_sampler.set_epoch(epoch)
821
-
822
- net_g.train(); net_d.train()
823
-
824
- if device.type == "cuda" and cache_data_in_gpu:
825
- data_iterator = cache
826
- if cache == []:
827
- for batch_idx, info in enumerate(train_loader):
828
- cache.append((batch_idx, [tensor.cuda(device_id, non_blocking=True) for tensor in info]))
829
- else: shuffle(cache)
830
- else: data_iterator = enumerate(train_loader)
831
-
832
- epoch_recorder = EpochRecorder()
833
- with tqdm(total=len(train_loader), leave=False) as pbar:
834
- for batch_idx, info in data_iterator:
835
- if device.type == "cuda" and not cache_data_in_gpu: info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
836
- elif device.type != "cuda": info = [tensor.to(device) for tensor in info]
837
-
838
- phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, _, sid = info
839
- pitch = pitch if pitch_guidance else None
840
- pitchf = pitchf if pitch_guidance else None
841
-
842
- with autocast(enabled=main_config.is_half and device.type == "cuda"):
843
- y_hat, ids_slice, _, z_mask, (_, z_p, m_p, logs_p, _, logs_q) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
844
- mel = spec_to_mel_torch(spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax)
845
- y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
846
-
847
- with autocast(enabled=main_config.is_half and device.type == "cuda"):
848
- y_hat_mel = mel_spectrogram_torch(y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax)
849
-
850
- wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
851
- y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
852
-
853
- with autocast(enabled=main_config.is_half and device.type == "cuda"):
854
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
855
-
856
- optim_d.zero_grad()
857
- scaler.scale(loss_disc).backward()
858
- scaler.unscale_(optim_d)
859
- grad_norm_d = clip_grad_value(net_d.parameters(), None)
860
- scaler.step(optim_d)
861
-
862
- with autocast(enabled=main_config.is_half and device.type == "cuda"):
863
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
864
- with autocast(enabled=main_config.is_half and device.type == "cuda"):
865
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
866
- loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
867
- loss_fm = feature_loss(fmap_r, fmap_g)
868
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
869
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
870
- if loss_gen_all < lowest_value["value"]:
871
- lowest_value["value"] = loss_gen_all
872
- lowest_value["step"] = global_step
873
- lowest_value["epoch"] = epoch
874
- if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])
875
-
876
- optim_g.zero_grad()
877
- scaler.scale(loss_gen_all).backward()
878
- scaler.unscale_(optim_g)
879
- grad_norm_g = clip_grad_value(net_g.parameters(), None)
880
- scaler.step(optim_g)
881
- scaler.update()
882
-
883
- if rank == 0 and global_step % config.train.log_interval == 0:
884
- if loss_mel > 75: loss_mel = 75
885
- if loss_kl > 9: loss_kl = 9
886
-
887
- scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": optim_g.param_groups[0]["lr"], "grad/norm_d": grad_norm_d, "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}
888
- scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
889
- scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
890
- scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
891
-
892
- with torch.no_grad():
893
- o, *_ = net_g.module.infer(*reference) if hasattr(net_g, "module") else net_g.infer(*reference)
894
-
895
- summarize(writer=writer, global_step=global_step, images={"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())}, scalars=scalar_dict, audios={f"gen/audio_{global_step:07d}": o[0, :, :]}, audio_sample_rate=config.data.sample_rate)
896
-
897
- global_step += 1
898
- pbar.update(1)
899
-
900
- def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
901
- if len(smoothed_loss_history) < threshold + 1: return False
902
- for i in range(-threshold, -1):
903
- if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
904
- if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False
905
- return True
906
-
907
- def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
908
- smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
909
- smoothed_loss_history.append(smoothed_value)
910
- return smoothed_value
911
-
912
- def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
913
- with open(file_path, "w") as f:
914
- json.dump({"loss_disc_history": loss_disc_history, "smoothed_loss_disc_history": smoothed_loss_disc_history, "loss_gen_history": loss_gen_history, "smoothed_loss_gen_history": smoothed_loss_gen_history}, f)
915
-
916
- model_add, model_del = [], []
917
- done = False
918
-
919
- if rank == 0:
920
- if epoch % save_every_epoch == False:
921
- checkpoint_suffix = f"{'latest' if save_only_latest else global_step}.pth"
922
- save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
923
- save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
924
- if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
925
-
926
- if overtraining_detector and epoch > 1:
927
- current_loss_disc = float(loss_disc)
928
- loss_disc_history.append(current_loss_disc)
929
- smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
930
- is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
931
-
932
- if is_overtraining_disc: consecutive_increases_disc += 1
933
- else: consecutive_increases_disc = 0
934
-
935
- current_loss_gen = float(lowest_value["value"])
936
- loss_gen_history.append(current_loss_gen)
937
- smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
938
- is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
939
-
940
- if is_overtraining_gen: consecutive_increases_gen += 1
941
- else: consecutive_increases_gen = 0
942
-
943
- if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
944
-
945
- if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
946
- logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
947
- done = True
948
- else:
949
- logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
950
- for file in glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth")):
951
- model_del.append(file)
952
-
953
- model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))
954
-
955
- if epoch >= custom_total_epoch:
956
- logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
957
- logger.info(translations["training_info"].format(lowest_value_rounded=round(float(lowest_value["value"]), 3), lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
958
-
959
- pid_file_path = os.path.join(experiment_dir, "config.json")
960
- with open(pid_file_path, "r") as pid_file:
961
- pid_data = json.load(pid_file)
962
-
963
- with open(pid_file_path, "w") as pid_file:
964
- pid_data.pop("process_pids", None)
965
- json.dump(pid_data, pid_file, indent=4)
966
-
967
- if os.path.exists(os.path.join(experiment_dir, "train_pid.txt")): os.remove(os.path.join(experiment_dir, "train_pid.txt"))
968
- model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
969
- done = True
970
-
971
- for m in model_del:
972
- os.remove(m)
973
-
974
- if model_add:
975
- ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())
976
- for m in model_add:
977
- extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_path=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author, vocoder=vocoder)
978
-
979
- lowest_value_rounded = round(float(lowest_value["value"]), 3)
980
-
981
- if epoch > 1 and overtraining_detector: logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=(overtraining_threshold - consecutive_increases_gen), remaining_epochs_disc=((overtraining_threshold * 2) - consecutive_increases_disc), smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
982
- elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
983
- else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))
984
-
985
- last_loss_gen_all = loss_gen_all
986
- if done: os._exit(0)
987
-
988
- if __name__ == "__main__":
989
- torch.multiprocessing.set_start_method("spawn")
990
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/commons.py DELETED
@@ -1,60 +0,0 @@
1
- import torch
2
-
3
-
4
-
5
- def init_weights(m, mean=0.0, std=0.01):
6
- if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)
7
-
8
- def get_padding(kernel_size, dilation=1):
9
- return int((kernel_size * dilation - dilation) / 2)
10
-
11
- def convert_pad_shape(pad_shape):
12
- return [item for sublist in pad_shape[::-1] for item in sublist]
13
-
14
- def slice_segments(x, ids_str, segment_size = 4, dim = 2):
15
- if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
16
- elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])
17
-
18
- for i in range(x.size(0)):
19
- idx_str = ids_str[i].item()
20
- idx_end = idx_str + segment_size
21
-
22
- if dim == 2: ret[i] = x[i, idx_str:idx_end]
23
- else: ret[i] = x[i, :, idx_str:idx_end]
24
-
25
- return ret
26
-
27
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
28
- b, _, t = x.size()
29
- if x_lengths is None: x_lengths = t
30
-
31
- ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)
32
-
33
- return slice_segments(x, ids_str, segment_size, dim=3), ids_str
34
-
35
- @torch.jit.script
36
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
37
- n_channels_int = n_channels[0]
38
-
39
- in_act = input_a + input_b
40
-
41
- return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])
42
-
43
- def sequence_mask(length, max_length = None):
44
- if max_length is None: max_length = length.max()
45
-
46
- return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)
47
-
48
- def clip_grad_value(parameters, clip_value, norm_type=2):
49
- if isinstance(parameters, torch.Tensor): parameters = [parameters]
50
- norm_type = float(norm_type)
51
-
52
- if clip_value is not None: clip_value = float(clip_value)
53
- total_norm = 0
54
-
55
- for p in list(filter(lambda p: p.grad is not None, parameters)):
56
- total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type
57
-
58
- if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)
59
-
60
- return total_norm ** (1.0 / norm_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/modules.py DELETED
@@ -1,60 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- sys.path.append(os.getcwd())
6
-
7
- from .commons import fused_add_tanh_sigmoid_multiply
8
-
9
- class WaveNet(torch.nn.Module):
10
- def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
11
- super(WaveNet, self).__init__()
12
- assert kernel_size % 2 == 1
13
- self.hidden_channels = hidden_channels
14
- self.kernel_size = (kernel_size,)
15
- self.dilation_rate = dilation_rate
16
- self.n_layers = n_layers
17
- self.gin_channels = gin_channels
18
- self.p_dropout = p_dropout
19
- self.in_layers = torch.nn.ModuleList()
20
- self.res_skip_layers = torch.nn.ModuleList()
21
- self.drop = torch.nn.Dropout(p_dropout)
22
- if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")
23
- dilations = [dilation_rate ** i for i in range(n_layers)]
24
- paddings = [(kernel_size * d - d) // 2 for d in dilations]
25
-
26
- for i in range(n_layers):
27
- in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
28
- in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
29
- self.in_layers.append(in_layer)
30
- res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
31
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
32
- res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
33
- self.res_skip_layers.append(res_skip_layer)
34
-
35
- def forward(self, x, x_mask, g=None):
36
- output = x.clone().zero_()
37
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
38
-
39
- if g is not None: g = self.cond_layer(g)
40
-
41
- for i in range(self.n_layers):
42
- x_in = self.in_layers[i](x)
43
- g_l = (g[:, i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels, :] if g is not None else 0)
44
- res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))
45
-
46
- if i < self.n_layers - 1:
47
- x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
48
- output = output + res_skip_acts[:, self.hidden_channels :, :]
49
- else: output = output + res_skip_acts
50
-
51
- return output * x_mask
52
-
53
- def remove_weight_norm(self):
54
- if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)
55
-
56
- for l in self.in_layers:
57
- torch.nn.utils.remove_weight_norm(l)
58
-
59
- for l in self.res_skip_layers:
60
- torch.nn.utils.remove_weight_norm(l)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/mrf_hifigan.py DELETED
@@ -1,150 +0,0 @@
1
- import math
2
- import torch
3
-
4
- import numpy as np
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from torch.nn.utils import remove_weight_norm
9
- from torch.utils.checkpoint import checkpoint
10
- from torch.nn.utils.parametrizations import weight_norm
11
-
12
- LRELU_SLOPE = 0.1
13
-
14
- class MRFLayer(nn.Module):
15
- def __init__(self, channels, kernel_size, dilation):
16
- super().__init__()
17
- self.conv1 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
18
- self.conv2 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))
19
-
20
- def forward(self, x):
21
- return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))
22
-
23
- def remove_weight_norm(self):
24
- remove_weight_norm(self.conv1)
25
- remove_weight_norm(self.conv2)
26
-
27
- class MRFBlock(nn.Module):
28
- def __init__(self, channels, kernel_size, dilations):
29
- super().__init__()
30
- self.layers = nn.ModuleList()
31
-
32
- for dilation in dilations:
33
- self.layers.append(MRFLayer(channels, kernel_size, dilation))
34
-
35
- def forward(self, x):
36
- for layer in self.layers:
37
- x = layer(x)
38
-
39
- return x
40
-
41
- def remove_weight_norm(self):
42
- for layer in self.layers:
43
- layer.remove_weight_norm()
44
-
45
- class SineGenerator(nn.Module):
46
- def __init__(self, samp_rate, harmonic_num = 0, sine_amp = 0.1, noise_std = 0.003, voiced_threshold = 0):
47
- super(SineGenerator, self).__init__()
48
- self.sine_amp = sine_amp
49
- self.noise_std = noise_std
50
- self.harmonic_num = harmonic_num
51
- self.dim = self.harmonic_num + 1
52
- self.sampling_rate = samp_rate
53
- self.voiced_threshold = voiced_threshold
54
-
55
- def _f02uv(self, f0):
56
- return torch.ones_like(f0) * (f0 > self.voiced_threshold)
57
-
58
- def _f02sine(self, f0_values):
59
- rad_values = (f0_values / self.sampling_rate) % 1
60
- rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
61
- rand_ini[:, 0] = 0
62
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
63
- tmp_over_one = torch.cumsum(rad_values, 1) % 1
64
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
65
- cumsum_shift = torch.zeros_like(rad_values)
66
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
67
-
68
- return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
69
-
70
- def forward(self, f0):
71
- with torch.no_grad():
72
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
73
- f0_buf[:, :, 0] = f0[:, :, 0]
74
-
75
- for idx in np.arange(self.harmonic_num):
76
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
77
-
78
- sine_waves = self._f02sine(f0_buf) * self.sine_amp
79
- uv = self._f02uv(f0)
80
- sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
81
-
82
- return sine_waves
83
-
84
- class SourceModuleHnNSF(nn.Module):
85
- def __init__(self, sampling_rate, harmonic_num = 0, sine_amp = 0.1, add_noise_std = 0.003, voiced_threshold = 0):
86
- super(SourceModuleHnNSF, self).__init__()
87
- self.sine_amp = sine_amp
88
- self.noise_std = add_noise_std
89
- self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
90
- self.l_linear = nn.Linear(harmonic_num + 1, 1)
91
- self.l_tanh = nn.Tanh()
92
-
93
- def forward(self, x):
94
- return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))
95
-
96
- class HiFiGANMRFGenerator(nn.Module):
97
- def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing = False):
98
- super().__init__()
99
- self.num_kernels = len(resblock_kernel_sizes)
100
- self.checkpointing = checkpointing
101
- self.f0_upsample = nn.Upsample(scale_factor=np.prod(upsample_rates))
102
- self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
103
- self.conv_pre = weight_norm(nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
104
- self.upsamples = nn.ModuleList()
105
- self.noise_convs = nn.ModuleList()
106
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
107
-
108
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
109
- self.upsamples.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
110
- stride = stride_f0s[i]
111
- kernel = 1 if stride == 1 else stride * 2 - stride % 2
112
- self.noise_convs.append(nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
113
-
114
- self.mrfs = nn.ModuleList()
115
- for i in range(len(self.upsamples)):
116
- channel = upsample_initial_channel // (2 ** (i + 1))
117
- self.mrfs.append(nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))
118
-
119
- self.conv_post = weight_norm(nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
120
- if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
121
-
122
- def forward(self, x, f0, g = None):
123
- har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
124
- x = self.conv_pre(x)
125
- if g is not None: x += self.cond(g)
126
-
127
- for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
128
- x = F.leaky_relu(x, LRELU_SLOPE)
129
-
130
- if self.training and self.checkpointing:
131
- x = checkpoint(ups, x, use_reentrant=False) + noise_conv(har_source)
132
- xs = sum([checkpoint(layer, x, use_reentrant=False) for layer in mrf])
133
- else:
134
- x = ups(x) + noise_conv(har_source)
135
- xs = sum([layer(x) for layer in mrf])
136
-
137
- x = xs / self.num_kernels
138
-
139
- return torch.tanh(self.conv_post(F.leaky_relu(x)))
140
-
141
- def remove_weight_norm(self):
142
- remove_weight_norm(self.conv_pre)
143
-
144
- for up in self.upsamples:
145
- remove_weight_norm(up)
146
-
147
- for mrf in self.mrfs:
148
- mrf.remove_weight_norm()
149
-
150
- remove_weight_norm(self.conv_post)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/onnx_export.py DELETED
@@ -1,50 +0,0 @@
1
- import os
2
- import io
3
- import sys
4
- import onnx
5
- import json
6
- import torch
7
- import onnxsim
8
- import warnings
9
-
10
- sys.path.append(os.getcwd())
11
-
12
- from main.library.algorithm.synthesizers import SynthesizerONNX
13
-
14
- warnings.filterwarnings("ignore")
15
-
16
- def onnx_exporter(input_path, output_path, is_half=False, device="cpu"):
17
- cpt = (torch.load(input_path, map_location="cpu") if os.path.isfile(input_path) else None)
18
- cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
19
-
20
- model_name, model_author, epochs, steps, version, f0, model_hash, vocoder, creation_date = cpt.get("model_name", None), cpt.get("author", None), cpt.get("epoch", None), cpt.get("step", None), cpt.get("version", "v1"), cpt.get("f0", 1), cpt.get("model_hash", None), cpt.get("vocoder", "Default"), cpt.get("creation_date", None)
21
- text_enc_hidden_dim = 768 if version == "v2" else 256
22
- tgt_sr = cpt["config"][-1]
23
-
24
- net_g = SynthesizerONNX(*cpt["config"], use_f0=f0, text_enc_hidden_dim=text_enc_hidden_dim, vocoder=vocoder, checkpointing=False)
25
- net_g.load_state_dict(cpt["weight"], strict=False)
26
- net_g.eval().to(device)
27
- net_g = (net_g.half() if is_half else net_g.float())
28
-
29
- phone = torch.rand(1, 200, text_enc_hidden_dim).to(device)
30
- phone_length = torch.tensor([200]).long().to(device)
31
- ds = torch.LongTensor([0]).to(device)
32
- rnd = torch.rand(1, 192, 200).to(device)
33
-
34
- if f0:
35
- args = (phone, phone_length, ds, rnd, torch.randint(size=(1, 200), low=5, high=255).to(device), torch.rand(1, 200).to(device))
36
- input_names = ["phone", "phone_lengths", "ds", "rnd", "pitch", "pitchf"]
37
- dynamic_axes = {"phone": [1], "rnd": [2], "pitch": [1], "pitchf": [1]}
38
- else:
39
- args = (phone, phone_length, ds, rnd)
40
- input_names = ["phone", "phone_lengths", "ds", "rnd"]
41
- dynamic_axes = {"phone": [1], "rnd": [2]}
42
-
43
- with io.BytesIO() as model:
44
- torch.onnx.export(net_g, args, model, do_constant_folding=True, opset_version=17, verbose=False, input_names=input_names, output_names=["audio"], dynamic_axes=dynamic_axes)
45
-
46
- model, _ = onnxsim.simplify(onnx.load_model_from_string(model.getvalue()))
47
- model.metadata_props.append(onnx.StringStringEntryProto(key="model_info", value=json.dumps({"model_name": model_name, "author": model_author, "epoch": epochs, "step": steps, "version": version, "sr": tgt_sr, "f0": f0, "model_hash": model_hash, "creation_date": creation_date, "vocoder": vocoder, "text_enc_hidden_dim": text_enc_hidden_dim})))
48
-
49
- onnx.save(model, output_path)
50
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/refinegan.py DELETED
@@ -1,170 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
-
6
- import numpy as np
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- from torch.utils.checkpoint import checkpoint
11
- from torch.nn.utils import remove_weight_norm
12
- from torch.nn.utils.parametrizations import weight_norm
13
-
14
- sys.path.append(os.getcwd())
15
-
16
- from main.library.algorithm.commons import init_weights, get_padding
17
-
18
-
19
- class ResBlock(nn.Module):
20
- def __init__(self, channels, kernel_size = 7, dilation = (1, 3, 5), leaky_relu_slope = 0.2):
21
- super().__init__()
22
- self.leaky_relu_slope = leaky_relu_slope
23
- self.convs1 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for d in dilation])
24
- self.convs1.apply(init_weights)
25
- self.convs2 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1, padding=get_padding(kernel_size, 1))) for _ in dilation])
26
- self.convs2.apply(init_weights)
27
-
28
- def forward(self, x):
29
- for c1, c2 in zip(self.convs1, self.convs2):
30
- x = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope)) + x
31
-
32
- return x
33
-
34
- def remove_weight_norm(self):
35
- for c1, c2 in zip(self.convs1, self.convs2):
36
- remove_weight_norm(c1)
37
- remove_weight_norm(c2)
38
-
39
- class AdaIN(nn.Module):
40
- def __init__(self, *, channels, leaky_relu_slope = 0.2):
41
- super().__init__()
42
- self.weight = nn.Parameter(torch.ones(channels))
43
- self.activation = nn.LeakyReLU(leaky_relu_slope)
44
-
45
- def forward(self, x):
46
- return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))
47
-
48
- class ParallelResBlock(nn.Module):
49
- def __init__(self, *, in_channels, out_channels, kernel_sizes = (3, 7, 11), dilation = (1, 3, 5), leaky_relu_slope = 0.2):
50
- super().__init__()
51
- self.in_channels = in_channels
52
- self.out_channels = out_channels
53
- self.input_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3)
54
- self.input_conv.apply(init_weights)
55
- self.blocks = nn.ModuleList([nn.Sequential(AdaIN(channels=out_channels), ResBlock(out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])
56
-
57
- def forward(self, x):
58
- x = self.input_conv(x)
59
- return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
60
-
61
- def remove_weight_norm(self):
62
- remove_weight_norm(self.input_conv)
63
- for block in self.blocks:
64
- block[1].remove_weight_norm()
65
-
66
- class SineGenerator(nn.Module):
67
- def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
68
- super(SineGenerator, self).__init__()
69
- self.sine_amp = sine_amp
70
- self.noise_std = noise_std
71
- self.harmonic_num = harmonic_num
72
- self.dim = self.harmonic_num + 1
73
- self.sampling_rate = samp_rate
74
- self.voiced_threshold = voiced_threshold
75
- self.merge = nn.Sequential(nn.Linear(self.dim, 1, bias=False), nn.Tanh())
76
-
77
- def _f02uv(self, f0):
78
- return torch.ones_like(f0) * (f0 > self.voiced_threshold)
79
-
80
- def _f02sine(self, f0_values):
81
- rad_values = (f0_values / self.sampling_rate) % 1
82
- rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
83
-
84
- rand_ini[:, 0] = 0
85
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
86
-
87
- tmp_over_one = torch.cumsum(rad_values, 1) % 1
88
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
89
-
90
- cumsum_shift = torch.zeros_like(rad_values)
91
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
92
-
93
- return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
94
-
95
- def forward(self, f0):
96
- with torch.no_grad():
97
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
98
- f0_buf[:, :, 0] = f0[:, :, 0]
99
-
100
- for idx in np.arange(self.harmonic_num):
101
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
102
-
103
- sine_waves = self._f02sine(f0_buf) * self.sine_amp
104
- uv = self._f02uv(f0)
105
- sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
106
-
107
- return self.merge(sine_waves)
108
-
109
- class RefineGANGenerator(nn.Module):
110
- def __init__(self, *, sample_rate = 44100, upsample_rates = (8, 8, 2, 2), leaky_relu_slope = 0.2, num_mels = 128, gin_channels = 256, checkpointing = False, upsample_initial_channel = 512):
111
- super().__init__()
112
- self.upsample_rates = upsample_rates
113
- self.checkpointing = checkpointing
114
- self.leaky_relu_slope = leaky_relu_slope
115
- self.upp = np.prod(upsample_rates)
116
- self.m_source = SineGenerator(sample_rate)
117
- self.pre_conv = weight_norm(nn.Conv1d(1, upsample_initial_channel // 2, 7, 1, padding=3))
118
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
119
-
120
- channels = upsample_initial_channel
121
- self.downsample_blocks = nn.ModuleList([])
122
-
123
- for i, _ in enumerate(upsample_rates):
124
- stride = stride_f0s[i]
125
- kernel = 1 if stride == 1 else stride * 2 - stride % 2
126
-
127
- self.downsample_blocks.append(weight_norm(nn.Conv1d(1, channels // 2 ** (i + 2), kernel, stride, padding=0 if stride == 1 else (kernel - stride) // 2)))
128
-
129
- self.mel_conv = weight_norm(nn.Conv1d(num_mels, channels // 2, 7, 1, padding=3))
130
- self.mel_conv.apply(init_weights)
131
-
132
- if gin_channels != 0: self.cond = nn.Conv1d(256, channels // 2, 1)
133
-
134
- self.upsample_blocks = nn.ModuleList([])
135
- self.upsample_conv_blocks = nn.ModuleList([])
136
-
137
- for rate in upsample_rates:
138
- new_channels = channels // 2
139
- self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
140
- self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
141
- channels = new_channels
142
-
143
- self.conv_post = weight_norm(nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False))
144
- self.conv_post.apply(init_weights)
145
-
146
- def forward(self, mel, f0, g = None):
147
- har_source = self.m_source(F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear").transpose(1, 2)).transpose(1, 2)
148
- x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")
149
-
150
- mel = self.mel_conv(mel)
151
- if g is not None: mel += self.cond(g)
152
-
153
- x = torch.cat([mel, x], dim=1)
154
-
155
- for ups, res, down in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks):
156
- x = F.leaky_relu(x, self.leaky_relu_slope)
157
- x = checkpoint(res, torch.cat([checkpoint(ups, x, use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([ups(x), down(har_source)], dim=1))
158
-
159
- return torch.tanh(self.conv_post(F.leaky_relu(x, self.leaky_relu_slope)))
160
-
161
- def remove_weight_norm(self):
162
- remove_weight_norm(self.pre_conv)
163
- remove_weight_norm(self.mel_conv)
164
- remove_weight_norm(self.conv_post)
165
-
166
- for block in self.downsample_blocks:
167
- block.remove_weight_norm()
168
-
169
- for block in self.upsample_conv_blocks:
170
- block.remove_weight_norm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/residuals.py DELETED
@@ -1,140 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- from torch.nn.utils import remove_weight_norm
6
- from torch.nn.utils.parametrizations import weight_norm
7
-
8
- sys.path.append(os.getcwd())
9
-
10
- from .modules import WaveNet
11
- from .commons import get_padding, init_weights
12
-
13
-
14
- LRELU_SLOPE = 0.1
15
-
16
- def create_conv1d_layer(channels, kernel_size, dilation):
17
- return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))
18
-
19
- def apply_mask(tensor, mask):
20
- return tensor * mask if mask is not None else tensor
21
-
22
- class ResBlockBase(torch.nn.Module):
23
- def __init__(self, channels, kernel_size, dilations):
24
- super(ResBlockBase, self).__init__()
25
-
26
- self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
27
- self.convs1.apply(init_weights)
28
-
29
- self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
30
- self.convs2.apply(init_weights)
31
-
32
- def forward(self, x, x_mask=None):
33
- for c1, c2 in zip(self.convs1, self.convs2):
34
- x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x
35
-
36
- return apply_mask(x, x_mask)
37
-
38
- def remove_weight_norm(self):
39
- for conv in self.convs1 + self.convs2:
40
- remove_weight_norm(conv)
41
-
42
- class ResBlock(ResBlockBase):
43
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
44
- super(ResBlock, self).__init__(channels, kernel_size, dilation)
45
-
46
- class Log(torch.nn.Module):
47
- def forward(self, x, x_mask, reverse=False, **kwargs):
48
- if not reverse:
49
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
50
- return y, torch.sum(-y, [1, 2])
51
- else: return torch.exp(x) * x_mask
52
-
53
- class Flip(torch.nn.Module):
54
- def forward(self, x, *args, reverse=False, **kwargs):
55
- x = torch.flip(x, [1])
56
-
57
- if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
58
- else: return x
59
-
60
- class ElementwiseAffine(torch.nn.Module):
61
- def __init__(self, channels):
62
- super().__init__()
63
- self.channels = channels
64
- self.m = torch.nn.Parameter(torch.zeros(channels, 1))
65
- self.logs = torch.nn.Parameter(torch.zeros(channels, 1))
66
-
67
- def forward(self, x, x_mask, reverse=False, **kwargs):
68
- if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
69
- else: return (x - self.m) * torch.exp(-self.logs) * x_mask
70
-
71
- class ResidualCouplingBlock(torch.nn.Module):
72
- def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
73
- super(ResidualCouplingBlock, self).__init__()
74
- self.channels = channels
75
- self.hidden_channels = hidden_channels
76
- self.kernel_size = kernel_size
77
- self.dilation_rate = dilation_rate
78
- self.n_layers = n_layers
79
- self.n_flows = n_flows
80
- self.gin_channels = gin_channels
81
- self.flows = torch.nn.ModuleList()
82
-
83
- for _ in range(n_flows):
84
- self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
85
- self.flows.append(Flip())
86
-
87
- def forward(self, x, x_mask, g = None, reverse = False):
88
- if not reverse:
89
- for flow in self.flows:
90
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
91
- else:
92
- for flow in reversed(self.flows):
93
- x = flow.forward(x, x_mask, g=g, reverse=reverse)
94
-
95
- return x
96
-
97
- def remove_weight_norm(self):
98
- for i in range(self.n_flows):
99
- self.flows[i * 2].remove_weight_norm()
100
-
101
- def __prepare_scriptable__(self):
102
- for i in range(self.n_flows):
103
- for hook in self.flows[i * 2]._forward_pre_hooks.values():
104
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])
105
-
106
- return self
107
-
108
- class ResidualCouplingLayer(torch.nn.Module):
109
- def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
110
- assert channels % 2 == 0, "Channels/2"
111
- super().__init__()
112
- self.channels = channels
113
- self.hidden_channels = hidden_channels
114
- self.kernel_size = kernel_size
115
- self.dilation_rate = dilation_rate
116
- self.n_layers = n_layers
117
- self.half_channels = channels // 2
118
- self.mean_only = mean_only
119
-
120
- self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
121
- self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
122
- self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
123
-
124
- self.post.weight.data.zero_()
125
- self.post.bias.data.zero_()
126
-
127
- def forward(self, x, x_mask, g=None, reverse=False):
128
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
129
- stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask
130
-
131
- if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
132
- else:
133
- m = stats
134
- logs = torch.zeros_like(m)
135
-
136
- if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
137
- else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)
138
-
139
- def remove_weight_norm(self):
140
- self.enc.remove_weight_norm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/separator.py DELETED
@@ -1,320 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import yaml
5
- import torch
6
- import codecs
7
- import hashlib
8
- import logging
9
- import platform
10
- import warnings
11
- import requests
12
- import onnxruntime
13
-
14
- from importlib import metadata, import_module
15
-
16
- now_dir = os.getcwd()
17
- sys.path.append(now_dir)
18
-
19
- from main.configs.config import Config
20
- from main.tools.huggingface import HF_download_file
21
-
22
- translations = Config().translations
23
-
24
-
25
class Separator:
    """Facade around the UVR5 source-separation models (MDX / Demucs).

    Resolves and downloads model files from a remote model registry, selects
    the torch / ONNX Runtime device, then delegates the actual separation to
    an architecture-specific separator class loaded dynamically.
    """

    def __init__(self, logger=logging.getLogger(__name__), log_level=logging.INFO, log_formatter=None, model_file_dir="assets/models/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
        # NOTE(review): mdx_params/demucs_params are mutable default dicts shared
        # across all instances — confirm no caller mutates them in place.
        self.logger = logger
        self.log_level = log_level
        self.log_formatter = log_formatter
        self.log_handler = logging.StreamHandler()

        if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
        self.log_handler.setFormatter(self.log_formatter)

        if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
        if log_level > logging.DEBUG: warnings.filterwarnings("ignore")

        self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))
        self.model_file_dir = model_file_dir

        if output_dir is None:
            output_dir = now_dir
            self.logger.info(translations["output_dir_is_none"])

        self.output_dir = output_dir

        os.makedirs(self.model_file_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        self.output_format = output_format
        self.output_bitrate = output_bitrate

        if self.output_format is None: self.output_format = "wav"
        self.normalization_threshold = normalization_threshold
        # threshold must lie in (0, 1]
        if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])

        self.output_single_stem = output_single_stem
        if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))

        self.invert_using_spec = invert_using_spec
        if self.invert_using_spec: self.logger.debug(translations["step2"])

        self.sample_rate = int(sample_rate)
        self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
        self.torch_device = None
        self.torch_device_cpu = None
        self.torch_device_mps = None
        self.onnx_execution_provider = None
        self.model_instance = None
        self.model_is_uvr_vip = False
        self.model_friendly_name = None
        self.setup_accelerated_inferencing_device()

    def setup_accelerated_inferencing_device(self):
        """Log environment info and pick the best available inference device."""
        system_info = self.get_system_info()
        self.log_onnxruntime_packages()
        self.setup_torch_device(system_info)

    def get_system_info(self):
        """Log OS / Python / PyTorch versions; return platform.uname() info."""
        os_name = platform.system()
        os_version = platform.version()
        self.logger.info(f"{translations['os']}: {os_name} {os_version}")
        system_info = platform.uname()
        self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))
        python_version = platform.python_version()
        self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")
        pytorch_version = torch.__version__
        self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")

        return system_info

    def log_onnxruntime_packages(self):
        """Log which onnxruntime distributions (GPU/CPU) are installed, if any."""
        onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
        onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")

        if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
        if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")

    def setup_torch_device(self, system_info):
        """Choose CUDA, then Apple MPS (arm only), then CPU as a fallback."""
        hardware_acceleration_enabled = False
        ort_providers = onnxruntime.get_available_providers()
        self.torch_device_cpu = torch.device("cpu")

        if torch.cuda.is_available():
            self.configure_cuda(ort_providers)
            hardware_acceleration_enabled = True
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
            self.configure_mps(ort_providers)
            hardware_acceleration_enabled = True

        if not hardware_acceleration_enabled:
            self.logger.info(translations["running_in_cpu"])
            self.torch_device = self.torch_device_cpu
            self.onnx_execution_provider = ["CPUExecutionProvider"]

    def configure_cuda(self, ort_providers):
        """Select the CUDA torch device; use the CUDA ONNX provider when present."""
        self.logger.info(translations["running_in_cuda"])
        self.torch_device = torch.device("cuda")

        if "CUDAExecutionProvider" in ort_providers:
            self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
            self.onnx_execution_provider = ["CUDAExecutionProvider"]
        else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))

    def configure_mps(self, ort_providers):
        """Select the Apple MPS torch device; use CoreML ONNX provider when present."""
        self.logger.info(translations["set_torch_mps"])
        self.torch_device_mps = torch.device("mps")
        self.torch_device = self.torch_device_mps

        if "CoreMLExecutionProvider" in ort_providers:
            self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
            self.onnx_execution_provider = ["CoreMLExecutionProvider"]
        else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))

    def get_package_distribution(self, package_name):
        """Return importlib.metadata distribution for a package, or None if absent."""
        try:
            return metadata.distribution(package_name)
        except metadata.PackageNotFoundError:
            self.logger.debug(translations["python_not_install"].format(package_name=package_name))
            return None

    def get_model_hash(self, model_path):
        """MD5 of the last ~10 MB of the model file (registry lookup key).

        Falls back to hashing the whole file when the file is smaller than
        the seek offset.
        NOTE(review): the fallback `open(...)` handle is never closed — minor
        resource leak on the error path.
        """
        self.logger.debug(translations["hash"].format(model_path=model_path))

        try:
            with open(model_path, "rb") as f:
                # hash only the file tail; seeking past the start raises IOError
                f.seek(-10000 * 1024, 2)
                return hashlib.md5(f.read()).hexdigest()
        except IOError as e:
            self.logger.error(translations["ioerror"].format(e=e))
            return hashlib.md5(open(model_path, "rb").read()).hexdigest()

    def download_file_if_not_exists(self, url, output_path):
        """Download `url` to `output_path` unless the file already exists."""
        if os.path.isfile(output_path):
            self.logger.debug(translations["cancel_download"].format(output_path=output_path))
            return

        self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
        HF_download_file(url, output_path)

    def print_uvr_vip_message(self):
        """Warn when the loaded model is a VIP (donation-gated) UVR model."""
        if self.model_is_uvr_vip:
            self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
            self.logger.warning(translations["vip_print"])

    def list_supported_model_files(self):
        """Fetch the remote model registry JSON, grouped by architecture.

        The registry URL is rot13-obfuscated in the source.
        """
        response = requests.get(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/hie_zbqryf.wfba", "rot13"))
        response.raise_for_status()
        model_downloads_list = response.json()
        self.logger.debug(translations["load_download_json"])

        # MDX merges the regular and VIP lists; Demucs keeps only v4 entries
        return {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}}

    def download_model_files(self, model_filename):
        """Locate `model_filename` in the registry and download all its files.

        Returns (model_filename, model_type, model_friendly_name, model_path,
        yaml_config_filename). Raises ValueError when the registry has no
        entry matching the requested filename.
        """
        model_path = os.path.join(self.model_file_dir, model_filename)
        supported_model_files_grouped = self.list_supported_model_files()

        yaml_config_filename = None
        self.logger.debug(translations["search_model"].format(model_filename=model_filename))

        for model_type, model_list in supported_model_files_grouped.items():
            for model_friendly_name, model_download_list in model_list.items():
                self.model_is_uvr_vip = "VIP" in model_friendly_name
                model_repo_url_prefix = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/hie5_zbqryf", "rot13")

                # case 1: registry entry is a single filename string
                if isinstance(model_download_list, str) and model_download_list == model_filename:
                    self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
                    self.model_friendly_name = model_friendly_name

                    try:
                        self.download_file_if_not_exists(f"{model_repo_url_prefix}/MDX/{model_filename}", model_path)
                    except RuntimeError:
                        # not in the MDX folder; retry under Demucs
                        self.logger.warning(translations["not_found_model"])
                        self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{model_filename}", model_path)

                    self.print_uvr_vip_message()
                    self.logger.debug(translations["single_model_path"].format(model_path=model_path))

                    return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
                # case 2: registry entry is a dict of {file_name: file_url}
                elif isinstance(model_download_list, dict):
                    this_model_matches_input_filename = False

                    for file_name, file_url in model_download_list.items():
                        if file_name == model_filename or file_url == model_filename:
                            self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
                            this_model_matches_input_filename = True

                    if this_model_matches_input_filename:
                        self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
                        self.model_friendly_name = model_friendly_name
                        self.print_uvr_vip_message()

                        for config_key, config_value in model_download_list.items():
                            self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")

                            if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
                            elif config_key.endswith(".ckpt"):
                                try:
                                    self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_key}", os.path.join(self.model_file_dir, config_key))
                                except RuntimeError:
                                    self.logger.warning(translations["not_found_model_warehouse"])

                                # a .yaml request resolves to the ckpt plus its yaml config
                                if model_filename.endswith(".yaml"):
                                    self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
                                    self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
                                    self.logger.warning(translations["yaml_warning_3"])

                                model_filename = config_key
                                model_path = os.path.join(self.model_file_dir, f"{model_filename}")

                                yaml_config_filename = config_value
                                yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)

                                try:
                                    self.download_file_if_not_exists(f"{model_repo_url_prefix}/mdx_c_configs/{yaml_config_filename}", yaml_config_filepath)
                                except RuntimeError:
                                    self.logger.debug(translations["yaml_debug"])
                            else: self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_value}", os.path.join(self.model_file_dir, config_value))

                        self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
                        return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

        raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))

    def load_model_data_from_yaml(self, yaml_config_filename):
        """Load model parameters from a YAML config (path or bare filename)."""
        model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
        self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))

        # NOTE(review): yaml.load with FullLoader on a downloaded file; prefer
        # yaml.safe_load unless the configs genuinely need full-loader tags.
        model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
        self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))

        if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
        return model_data

    def load_model_data_using_hash(self, model_path):
        """Look up model parameters in the remote registry keyed by file hash."""
        self.logger.debug(translations["hash_md5"])
        model_hash = self.get_model_hash(model_path)

        self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
        mdx_model_data_path = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/zbqry_qngn.wfba", "rot13")
        self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))

        response = requests.get(mdx_model_data_path)
        response.raise_for_status()

        mdx_model_data_object = response.json()
        self.logger.debug(translations["load_mdx"])

        if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
        else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))

        self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
        return model_data

    def load_model(self, model_filename):
        """Download (if needed) and instantiate the separator for a model file."""
        self.logger.info(translations["loading_model"].format(model_filename=model_filename))
        load_model_start_time = time.perf_counter()
        model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
        self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))

        if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path

        common_params = {"logger": self.logger, "log_level": self.log_level, "torch_device": self.torch_device, "torch_device_cpu": self.torch_device_cpu, "torch_device_mps": self.torch_device_mps, "onnx_execution_provider": self.onnx_execution_provider, "model_name": model_filename.split(".")[0], "model_path": model_path, "model_data": self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path), "output_format": self.output_format, "output_bitrate": self.output_bitrate, "output_dir": self.output_dir, "normalization_threshold": self.normalization_threshold, "output_single_stem": self.output_single_stem, "invert_using_spec": self.invert_using_spec, "sample_rate": self.sample_rate}
        separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}

        if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
        if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])

        # resolve "module.Class" and import the architecture lazily
        self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
        module_name, class_name = separator_classes[model_type].split(".")
        separator_class = getattr(import_module(f"main.library.architectures.{module_name}"), class_name)

        self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
        self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])

        self.logger.debug(translations["loading_model_success"])
        self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")

    def separate(self, audio_file_path):
        """Run separation on an audio file; return the list of output files."""
        self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
        separate_start_time = time.perf_counter()

        self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
        output_files = self.model_instance.separate(audio_file_path)

        # free GPU memory and per-file state between runs
        self.model_instance.clear_gpu_cache()
        self.model_instance.clear_file_specific_paths()

        self.print_uvr_vip_message()

        self.logger.debug(translations["separator_success_3"])
        self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
        return output_files

    def download_model_and_data(self, model_filename):
        """Download the model files and log its resolved metadata (no loading)."""
        self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
        model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)

        if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
        self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=len(self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/stftpitchshift.py DELETED
@@ -1,250 +0,0 @@
1
- import numpy as np
2
-
3
- from numpy.lib.stride_tricks import sliding_window_view
4
-
5
def istft(frames, framesize, hopsize):
    """Inverse STFT via weighted overlap-add.

    `framesize` may be a scalar or an (analysis, synthesis) pair; when the two
    sizes differ, matching asymmetric windows are used.
    NOTE(review): zeroes the DC and Nyquist bins of `frames` in place, so the
    caller's array is mutated — confirm callers do not reuse it afterwards.
    """
    frames = np.atleast_2d(frames)
    assert frames.ndim == 2

    # first element = analysis window size, last = synthesis window size
    analysis_window_size = np.ravel(framesize)[0]
    synthesis_window_size = np.ravel(framesize)[-1]

    assert analysis_window_size >= synthesis_window_size

    A = asymmetric_analysis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(analysis_window_size)
    S = asymmetric_synthesis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(synthesis_window_size)

    # normalize so overlapping analysis*synthesis windows sum to unity gain
    W = S * hopsize / np.sum(A * S)
    N = frames.shape[0] * hopsize + analysis_window_size

    y = np.zeros((N), float)

    frames[:, 0] = 0
    frames[:, -1] = 0
    # frames0 is a writeable strided view aliasing y, so adding into the view
    # rows performs the overlap-add directly into the output buffer
    frames0 = sliding_window_view(y, analysis_window_size, writeable=True)[::hopsize]
    frames1 = np.fft.irfft(frames, axis=-1, norm='forward') * W

    for i in range(min(len(frames0), len(frames1))):
        frames0[i] += frames1[i]

    return y
31
-
32
def asymmetric_synthesis_window(analysis_window_size, synthesis_window_size):
    """Synthesis half of an asymmetric analysis/synthesis window pair.

    Zero everywhere except the last `synthesis_window_size` samples: a short
    compensated rising flank followed by the short window's falling flank.
    """
    total = analysis_window_size
    half = synthesis_window_size // 2

    short_window = symmetric_window(2 * half)
    long_window = symmetric_window(2 * total - 2 * half)

    out = np.zeros(total)
    lo, hi = total - 2 * half, total - half

    # rising flank compensates the long analysis window in the overlap region
    out[lo:hi] = np.square(short_window[:half]) / long_window[lo:hi]
    out[hi:] = short_window[half:]

    return out
43
-
44
def asymmetric_analysis_window(analysis_window_size, synthesis_window_size):
    """Analysis half of an asymmetric window pair: a long rising flank
    followed by a short falling flank."""
    total = analysis_window_size
    half = synthesis_window_size // 2
    split = total - half

    out = np.zeros(total)
    out[:split] = symmetric_window(2 * split)[:split]
    out[split:] = symmetric_window(2 * half)[half:]

    return out
53
-
54
def symmetric_window(symmetric_window_size):
    """Periodic Hann window of the given length."""
    size = symmetric_window_size
    phase = 2 * np.pi * np.arange(size) / size
    return (1 - np.cos(phase)) / 2
59
-
60
def stft(x, framesize, hopsize):
    """Short-time Fourier transform over hop-spaced overlapping frames.

    `framesize` may be a scalar or an (analysis, synthesis) pair; an
    asymmetric analysis window is used when the sizes differ.
    """
    signal = np.atleast_1d(x)
    assert signal.ndim == 1

    sizes = np.ravel(framesize)
    analysis_size, synthesis_size = sizes[0], sizes[-1]
    assert analysis_size >= synthesis_size

    if analysis_size == synthesis_size:
        window = symmetric_window(analysis_size)
    else:
        window = asymmetric_analysis_window(analysis_size, synthesis_size)

    segments = sliding_window_view(signal, analysis_size, writeable=False)[::hopsize]
    return np.fft.rfft(segments * window, axis=-1, norm='forward')
75
-
76
def normalize(frames, frames0):
    """Rescale each frame's real part so its energy matches the reference.

    Mutates `frames` in place and returns it; imaginary parts are untouched.
    Frames whose real part has zero energy are skipped.
    """
    for i in range(len(frames)):
        reference = np.real(frames0[i])
        current = np.real(frames[i])

        ref_energy = np.dot(reference, reference)
        cur_energy = np.dot(current, current)

        if cur_energy == 0:
            continue

        gain = np.sqrt(ref_energy / cur_energy)
        frames[i] = current * gain + 1j * np.imag(frames[i])

    return frames
87
-
88
def lowpass(cepstrum, quefrency):
    """Lifter the cepstrum in place: double the non-DC bins below the cutoff
    (folding in the symmetric half) and zero every bin above it."""
    cutoff = quefrency
    cepstrum[1:cutoff] = cepstrum[1:cutoff] * 2
    cepstrum[cutoff + 1:] = 0
    return cepstrum
93
-
94
def lifter(frames, quefrency):
    """Estimate a per-frame spectral envelope via cepstral liftering.

    Returns a real array the same shape as `frames`.
    """
    envelopes = np.zeros(frames.shape)

    for i, frame in enumerate(frames):
        # log magnitude; zeros give -inf / nan, which we deliberately allow
        with np.errstate(divide='ignore', invalid='ignore'):
            log_spectrum = np.log10(np.real(frame))

        cepstrum = np.fft.irfft(log_spectrum, norm='forward')
        smoothed = np.fft.rfft(lowpass(cepstrum, quefrency), norm='forward')
        envelopes[i] = np.power(10, np.real(smoothed))

    return envelopes
104
-
105
def resample(x, factor):
    """Linearly interpolate `x` onto a grid stretched by `factor`.

    The output has the same length and dtype as the input; positions past the
    source range are left at zero. `factor == 1` returns a plain copy.
    """
    if factor == 1:
        return x.copy()

    out = np.zeros(x.shape, dtype=x.dtype)

    src_len = len(x)
    dst_len = int(src_len * factor)

    idx = np.arange(min(src_len, dst_len))
    pos = idx * (src_len / dst_len)

    base = np.trunc(pos).astype(int)
    frac = pos - base

    # only positions with a valid right neighbour can be interpolated
    valid = (base >= 0) & (base < src_len - 1)
    out[idx[valid]] = frac[valid] * x[base[valid] + 1] + (1 - frac[valid]) * x[base[valid]]

    return out
122
-
123
def shiftpitch(frames, factors, samplerate):
    """Shift each encoded frame by every factor, keeping per bin the
    candidate with the strongest magnitude. Mutates `frames` in place."""
    nyquist = samplerate / 2

    for i in range(len(frames)):
        mags = np.vstack([resample(np.real(frames[i]), f) for f in factors])
        freqs = np.vstack([resample(np.imag(frames[i]), f) * f for f in factors])

        # discard candidates outside the representable frequency range
        mags[(freqs <= 0) | (freqs >= nyquist)] = 0

        best = np.argmax(mags, axis=0)[None, :]
        picked_mags = np.take_along_axis(mags, best, axis=0)
        picked_freqs = np.take_along_axis(freqs, best, axis=0)

        frames[i] = picked_mags + 1j * picked_freqs

    return frames
137
-
138
def wrap(x):
    """Wrap phase angles into the interval [-pi, pi)."""
    two_pi = 2 * np.pi
    return np.mod(x + np.pi, two_pi) - np.pi
140
-
141
def encode(frames, framesize, hopsize, samplerate):
    """Phase-vocoder analysis: convert STFT frames into per-bin
    (magnitude, instantaneous frequency) pairs stored as real + 1j*imag."""
    num_frames, num_bins = frames.shape
    analysis_framesize = np.ravel(framesize)[0]

    freq_per_bin = samplerate / analysis_framesize
    expected_phase_inc = 2 * np.pi * hopsize / analysis_framesize

    bins = np.arange(num_bins)
    previous_phase = np.zeros(num_bins)
    encoded = np.zeros((num_frames, num_bins), complex)

    for m, frame in enumerate(frames):
        phase = np.angle(frame)
        # phase advance since the previous frame minus the expected advance,
        # wrapped to [-pi, pi) and turned into a fractional-bin deviation
        deviation = wrap(phase - previous_phase - bins * expected_phase_inc) / expected_phase_inc
        previous_phase = phase

        encoded[m] = np.abs(frame) + 1j * ((bins + deviation) * freq_per_bin)

    return encoded
161
-
162
def decode(frames, framesize, hopsize, samplerate):
    """Phase-vocoder synthesis: convert (magnitude, frequency) frames back to
    complex STFT frames by integrating per-bin phase increments."""
    num_frames, num_bins = frames.shape
    analysis_framesize = np.ravel(framesize)[0]
    synthesis_framesize = np.ravel(framesize)[-1]

    freq_per_bin = samplerate / analysis_framesize
    phase_inc = 2 * np.pi * hopsize / analysis_framesize
    # hop-size compensation only applies when analysis/synthesis sizes differ
    timeshift = 2 * np.pi * synthesis_framesize * np.arange(num_bins) / num_bins if synthesis_framesize != analysis_framesize else 0

    bins = np.arange(num_bins)
    accumulated_phase = np.zeros(num_bins)
    decoded = np.zeros((num_frames, num_bins), complex)

    for m, frame in enumerate(frames):
        accumulated_phase += (bins + ((np.imag(frame) - bins * freq_per_bin) / freq_per_bin)) * phase_inc
        phase = accumulated_phase - timeshift
        decoded[m] = np.real(frame) * np.exp(1j * phase)

    return decoded
183
-
184
class StftPitchShift:
    """STFT-based pitch shifter (phase vocoder with optional cepstral
    formant preservation)."""

    def __init__(self, framesize, hopsize, samplerate):
        # framesize may be a scalar or an (analysis, synthesis) pair
        self.framesize = framesize
        self.hopsize = hopsize
        self.samplerate = samplerate

    def shiftpitch(self, input, factors = 1, quefrency = 0, distortion = 1, normalization = False):
        """Pitch-shift a 1-D signal.

        factors: one or more pitch-shift ratios (the strongest candidate wins
            per bin); quefrency: cepstral cutoff in seconds (0 disables formant
            preservation); distortion: formant-envelope resampling factor;
            normalization: match each output frame's energy to the input's.
        Integer inputs are mapped to [-1, 1] and back; output matches the
        input dtype and shape.
        """
        input = np.atleast_1d(input)
        dtype = input.dtype
        shape = input.shape

        input = np.squeeze(input)
        if input.ndim != 1: raise ValueError('input.ndim != 1')

        # integer PCM -> floats in [-1, 1]
        if np.issubdtype(dtype, np.integer):
            a, b = np.iinfo(dtype).min, np.iinfo(dtype).max
            input = ((input.astype(float) - a) / (b - a)) * 2 - 1
        elif not np.issubdtype(dtype, np.floating): raise TypeError('not np.issubdtype(dtype, np.floating)')

        def isnotnormal(x):
            # inf, nan, or subnormal/zero magnitudes
            return (np.isinf(x)) | (np.isnan(x)) | (abs(x) < np.finfo(x.dtype).tiny)

        framesize = self.framesize
        hopsize = self.hopsize
        samplerate = self.samplerate

        factors = np.asarray(factors).flatten()
        quefrency = int(quefrency * samplerate)  # seconds -> cepstral bins

        frames = encode(stft(input, framesize, hopsize), framesize, hopsize, samplerate)

        if normalization: frames0 = frames.copy()

        if quefrency:
            # formant preservation: divide out the spectral envelope, shift,
            # then re-apply the (optionally distorted) envelope
            envelopes = lifter(frames, quefrency)
            mask = isnotnormal(envelopes)

            frames.real /= envelopes
            frames.real[mask] = 0

            if distortion != 1:
                envelopes[mask] = 0

                for i in range(len(envelopes)):
                    envelopes[i] = resample(envelopes[i], distortion)

                mask = isnotnormal(envelopes)

            frames = shiftpitch(frames, factors, samplerate)
            frames.real *= envelopes
            frames.real[mask] = 0
        else: frames = shiftpitch(frames, factors, samplerate)

        if normalization: frames = normalize(frames, frames0)

        output = istft(decode(frames, framesize, hopsize, samplerate), framesize, hopsize)
        output.resize(shape, refcheck=False)

        # floats in [-1, 1] -> original integer PCM range
        if np.issubdtype(dtype, np.integer):
            a, b = np.iinfo(dtype).min, np.iinfo(dtype).max
            output = (((output + 1) / 2) * (b - a) + a).clip(a, b).astype(dtype)
        elif output.dtype != dtype: output = output.astype(dtype)

        assert output.dtype == dtype
        assert output.shape == shape

        return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/algorithm/synthesizers.py DELETED
@@ -1,490 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
- import numpy as np
6
- import torch.nn.functional as F
7
-
8
- from torch.nn.utils import remove_weight_norm
9
- from torch.utils.checkpoint import checkpoint
10
- from torch.nn.utils.parametrizations import weight_norm
11
-
12
- sys.path.append(os.getcwd())
13
-
14
- from .modules import WaveNet
15
- from .refinegan import RefineGANGenerator
16
- from .mrf_hifigan import HiFiGANMRFGenerator
17
- from .residuals import ResidualCouplingBlock, ResBlock, LRELU_SLOPE
18
- from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape
19
-
20
-
21
- class Generator(torch.nn.Module):
22
- def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
23
- super(Generator, self).__init__()
24
- self.num_kernels = len(resblock_kernel_sizes)
25
- self.num_upsamples = len(upsample_rates)
26
- self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
27
- self.ups_and_resblocks = torch.nn.ModuleList()
28
-
29
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
30
- self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
31
- ch = upsample_initial_channel // (2 ** (i + 1))
32
- for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
33
- self.ups_and_resblocks.append(ResBlock(ch, k, d))
34
-
35
- self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
36
- self.ups_and_resblocks.apply(init_weights)
37
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
38
-
39
- def forward(self, x, g = None):
40
- x = self.conv_pre(x)
41
- if g is not None: x = x + self.cond(g)
42
-
43
- resblock_idx = 0
44
-
45
- for _ in range(self.num_upsamples):
46
- x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
47
- resblock_idx += 1
48
- xs = 0
49
-
50
- for _ in range(self.num_kernels):
51
- xs += self.ups_and_resblocks[resblock_idx](x)
52
- resblock_idx += 1
53
-
54
- x = xs / self.num_kernels
55
-
56
- return torch.tanh(self.conv_post(F.leaky_relu(x)))
57
-
58
- def __prepare_scriptable__(self):
59
- for l in self.ups_and_resblocks:
60
- for hook in l._forward_pre_hooks.values():
61
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
62
-
63
- return self
64
-
65
- def remove_weight_norm(self):
66
- for l in self.ups_and_resblocks:
67
- remove_weight_norm(l)
68
-
69
- class SineGen(torch.nn.Module):
70
- def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
71
- super(SineGen, self).__init__()
72
- self.sine_amp = sine_amp
73
- self.noise_std = noise_std
74
- self.harmonic_num = harmonic_num
75
- self.dim = self.harmonic_num + 1
76
- self.sampling_rate = samp_rate
77
- self.voiced_threshold = voiced_threshold
78
-
79
- def _f02uv(self, f0):
80
- return torch.ones_like(f0) * (f0 > self.voiced_threshold)
81
-
82
- def _f02sine(self, f0, upp):
83
- rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
84
- rad += F.pad((torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5).cumsum(dim=1).fmod(1.0).to(f0), (0, 0, 1, 0), mode='constant')
85
- rad = rad.reshape(f0.shape[0], -1, 1)
86
- rad *= torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
87
- rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
88
- rand_ini[..., 0] = 0
89
- rad += rand_ini
90
-
91
- return torch.sin(2 * np.pi * rad)
92
-
93
- def forward(self, f0, upp):
94
- with torch.no_grad():
95
- f0 = f0.unsqueeze(-1)
96
- sine_waves = self._f02sine(f0, upp) * self.sine_amp
97
- uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
98
- sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
99
-
100
- return sine_waves
101
-
102
- class SourceModuleHnNSF(torch.nn.Module):
103
- def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
104
- super(SourceModuleHnNSF, self).__init__()
105
- self.sine_amp = sine_amp
106
- self.noise_std = add_noise_std
107
- self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
108
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
109
- self.l_tanh = torch.nn.Tanh()
110
-
111
- def forward(self, x, upsample_factor = 1):
112
- return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))
113
-
114
- class GeneratorNSF(torch.nn.Module):
115
- def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
116
- super(GeneratorNSF, self).__init__()
117
- self.num_kernels = len(resblock_kernel_sizes)
118
- self.num_upsamples = len(upsample_rates)
119
- self.upp = math.prod(upsample_rates)
120
- self.f0_upsamp = torch.nn.Upsample(scale_factor=self.upp)
121
- self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
122
-
123
- self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
124
- self.checkpointing = checkpointing
125
-
126
- self.ups = torch.nn.ModuleList()
127
- self.noise_convs = torch.nn.ModuleList()
128
-
129
- channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(self.num_upsamples)]
130
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < self.num_upsamples else 1 for i in range(self.num_upsamples)]
131
-
132
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
133
- self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
134
- stride = stride_f0s[i]
135
- kernel = 1 if stride == 1 else stride * 2 - stride % 2
136
- self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
137
-
138
- self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
139
- self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
140
-
141
- self.ups.apply(init_weights)
142
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
143
-
144
- def forward(self, x, f0, g = None):
145
- har_source = self.m_source(f0, self.upp).transpose(1, 2)
146
- x = self.conv_pre(x)
147
- if g is not None: x += self.cond(g)
148
-
149
- for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
150
- x = F.leaky_relu(x, LRELU_SLOPE)
151
-
152
- if self.training and self.checkpointing:
153
- x = checkpoint(ups, x, use_reentrant=False) + noise_convs(har_source)
154
- xs = sum([checkpoint(resblock, x, use_reentrant=False) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
155
- else:
156
- x = ups(x) + noise_convs(har_source)
157
- xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
158
-
159
- x = xs / self.num_kernels
160
-
161
- return torch.tanh(self.conv_post(F.leaky_relu(x)))
162
-
163
- def remove_weight_norm(self):
164
- for l in self.ups:
165
- remove_weight_norm(l)
166
-
167
- for l in self.resblocks:
168
- l.remove_weight_norm()
169
-
170
- class LayerNorm(torch.nn.Module):
171
- def __init__(self, channels, eps=1e-5, onnx=False):
172
- super().__init__()
173
- self.channels = channels
174
- self.eps = eps
175
- self.onnx = onnx
176
- self.gamma = torch.nn.Parameter(torch.ones(channels))
177
- self.beta = torch.nn.Parameter(torch.zeros(channels))
178
-
179
- def forward(self, x):
180
- x = x.transpose(1, -1)
181
- return (F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) if self.onnx else F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps)).transpose(1, -1)
182
-
183
- class MultiHeadAttention(torch.nn.Module):
184
- def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False, onnx=False):
185
- super().__init__()
186
- assert channels % n_heads == 0
187
- self.channels = channels
188
- self.out_channels = out_channels
189
- self.n_heads = n_heads
190
- self.p_dropout = p_dropout
191
- self.window_size = window_size
192
- self.heads_share = heads_share
193
- self.block_length = block_length
194
- self.proximal_bias = proximal_bias
195
- self.proximal_init = proximal_init
196
- self.onnx = onnx
197
- self.attn = None
198
- self.k_channels = channels // n_heads
199
- self.conv_q = torch.nn.Conv1d(channels, channels, 1)
200
- self.conv_k = torch.nn.Conv1d(channels, channels, 1)
201
- self.conv_v = torch.nn.Conv1d(channels, channels, 1)
202
- self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
203
- self.drop = torch.nn.Dropout(p_dropout)
204
-
205
- if window_size is not None:
206
- n_heads_rel = 1 if heads_share else n_heads
207
- rel_stddev = self.k_channels**-0.5
208
-
209
- self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
210
- self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
211
-
212
- torch.nn.init.xavier_uniform_(self.conv_q.weight)
213
- torch.nn.init.xavier_uniform_(self.conv_k.weight)
214
- torch.nn.init.xavier_uniform_(self.conv_v.weight)
215
-
216
- if proximal_init:
217
- with torch.no_grad():
218
- self.conv_k.weight.copy_(self.conv_q.weight)
219
- self.conv_k.bias.copy_(self.conv_q.bias)
220
-
221
- def forward(self, x, c, attn_mask=None):
222
- q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
223
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
224
-
225
- return self.conv_o(x)
226
-
227
- def attention(self, query, key, value, mask=None):
228
- b, d, t_s, t_t = (*key.size(), query.size(2))
229
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
232
-
233
- if self.window_size is not None:
234
- assert (t_s == t_t), "(t_s == t_t)"
235
- scores = scores + self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s, onnx=self.onnx)), onnx=self.onnx)
236
-
237
- if self.proximal_bias:
238
- assert t_s == t_t, "t_s == t_t"
239
- scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
240
-
241
- if mask is not None:
242
- scores = scores.masked_fill(mask == 0, -1e4)
243
- if self.block_length is not None:
244
- assert (t_s == t_t), "(t_s == t_t)"
245
- scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)
246
-
247
- p_attn = self.drop(F.softmax(scores, dim=-1))
248
- output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))
249
-
250
- if self.window_size is not None: output = output + self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn, onnx=self.onnx), self._get_relative_embeddings(self.emb_rel_v, t_s, onnx=self.onnx))
251
- return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn
252
-
253
- def _matmul_with_relative_values(self, x, y):
254
- return torch.matmul(x, y.unsqueeze(0))
255
-
256
- def _matmul_with_relative_keys(self, x, y):
257
- return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
258
-
259
- def _get_relative_embeddings(self, relative_embeddings, length, onnx=False):
260
- if onnx:
261
- pad_length = torch.clamp(length - (self.window_size + 1), min=0)
262
- slice_start_position = torch.clamp((self.window_size + 1) - length, min=0)
263
-
264
- return (F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
265
- else:
266
- pad_length = max(length - (self.window_size + 1), 0)
267
- slice_start_position = max((self.window_size + 1) - length, 0)
268
-
269
- return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
270
-
271
- def _relative_position_to_absolute_position(self, x, onnx=False):
272
- batch, heads, length, _ = x.size()
273
-
274
- return (F.pad(F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]).view([batch, heads, length * 2 * length]), [0, length - 1, 0, 0, 0, 0]).view([batch, heads, length + 1, 2 * length - 1]) if onnx else F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1]))[:, :, :length, length - 1 :]
275
-
276
- def _absolute_position_to_relative_position(self, x, onnx=False):
277
- batch, heads, length, _ = x.size()
278
-
279
- return (F.pad(F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]).view([batch, heads, length*length + length * (length - 1)]), [length, 0, 0, 0, 0, 0]).view([batch, heads, length, 2 * length]) if onnx else F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length]))[:, :, :, 1:]
280
-
281
- def _attention_bias_proximal(self, length):
282
- r = torch.arange(length, dtype=torch.float32)
283
-
284
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)
285
-
286
- class FFN(torch.nn.Module):
287
- def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False, onnx=False):
288
- super().__init__()
289
- self.in_channels = in_channels
290
- self.out_channels = out_channels
291
- self.filter_channels = filter_channels
292
- self.kernel_size = kernel_size
293
- self.p_dropout = p_dropout
294
- self.activation = activation
295
- self.causal = causal
296
- self.onnx = onnx
297
- self.padding = self._causal_padding if causal else self._same_padding
298
- self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
299
- self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
300
- self.drop = torch.nn.Dropout(p_dropout)
301
-
302
- def forward(self, x, x_mask):
303
- x = self.conv_1(self.padding(x * x_mask))
304
-
305
- return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask
306
-
307
- def _causal_padding(self, x):
308
- if self.kernel_size == 1: return x
309
-
310
- return F.pad(x, [self.kernel_size - 1, 0, 0, 0, 0, 0]) if self.onnx else F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))
311
-
312
- def _same_padding(self, x):
313
- if self.kernel_size == 1: return x
314
-
315
- return F.pad(x, [(self.kernel_size - 1) // 2, self.kernel_size // 2, 0, 0, 0, 0]) if self.onnx else F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))
316
-
317
- class Encoder(torch.nn.Module):
318
- def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, onnx=False, **kwargs):
319
- super().__init__()
320
- self.hidden_channels = hidden_channels
321
- self.filter_channels = filter_channels
322
- self.n_heads = n_heads
323
- self.n_layers = n_layers
324
- self.kernel_size = kernel_size
325
- self.p_dropout = p_dropout
326
- self.window_size = window_size
327
- self.drop = torch.nn.Dropout(p_dropout)
328
- self.attn_layers = torch.nn.ModuleList()
329
- self.norm_layers_1 = torch.nn.ModuleList()
330
- self.ffn_layers = torch.nn.ModuleList()
331
- self.norm_layers_2 = torch.nn.ModuleList()
332
-
333
- for _ in range(self.n_layers):
334
- self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size, onnx=onnx))
335
- self.norm_layers_1.append(LayerNorm(hidden_channels, onnx=onnx))
336
-
337
- self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, onnx=onnx))
338
- self.norm_layers_2.append(LayerNorm(hidden_channels, onnx=onnx))
339
-
340
- def forward(self, x, x_mask):
341
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
342
- x = x * x_mask
343
-
344
- for i in range(self.n_layers):
345
- x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
346
- x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))
347
-
348
- return x * x_mask
349
-
350
- class TextEncoder(torch.nn.Module):
351
- def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True, onnx=False):
352
- super(TextEncoder, self).__init__()
353
- self.out_channels = out_channels
354
- self.hidden_channels = hidden_channels
355
- self.filter_channels = filter_channels
356
- self.n_heads = n_heads
357
- self.n_layers = n_layers
358
- self.kernel_size = kernel_size
359
- self.p_dropout = float(p_dropout)
360
- self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
361
- self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
362
- if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
363
- self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), onnx=onnx)
364
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
365
-
366
- def forward(self, phone, pitch, lengths):
367
- x = torch.transpose(self.lrelu(((self.emb_phone(phone) if pitch is None else (self.emb_phone(phone) + self.emb_pitch(pitch))) * math.sqrt(self.hidden_channels))), 1, -1)
368
- x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
369
- m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)
370
-
371
- return m, logs, x_mask
372
-
373
- class PosteriorEncoder(torch.nn.Module):
374
- def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
375
- super(PosteriorEncoder, self).__init__()
376
- self.in_channels = in_channels
377
- self.out_channels = out_channels
378
- self.hidden_channels = hidden_channels
379
- self.kernel_size = kernel_size
380
- self.dilation_rate = dilation_rate
381
- self.n_layers = n_layers
382
- self.gin_channels = gin_channels
383
- self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
384
- self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
385
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
386
-
387
- def forward(self, x, x_lengths, g = None):
388
- x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
389
- m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)
390
-
391
- return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask
392
-
393
- def remove_weight_norm(self):
394
- self.enc.remove_weight_norm()
395
-
396
- class Synthesizer(torch.nn.Module):
397
- def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, onnx=False, **kwargs):
398
- super(Synthesizer, self).__init__()
399
- self.spec_channels = spec_channels
400
- self.inter_channels = inter_channels
401
- self.hidden_channels = hidden_channels
402
- self.filter_channels = filter_channels
403
- self.n_heads = n_heads
404
- self.n_layers = n_layers
405
- self.kernel_size = kernel_size
406
- self.p_dropout = float(p_dropout)
407
- self.resblock_kernel_sizes = resblock_kernel_sizes
408
- self.resblock_dilation_sizes = resblock_dilation_sizes
409
- self.upsample_rates = upsample_rates
410
- self.upsample_initial_channel = upsample_initial_channel
411
- self.upsample_kernel_sizes = upsample_kernel_sizes
412
- self.segment_size = segment_size
413
- self.gin_channels = gin_channels
414
- self.spk_embed_dim = spk_embed_dim
415
- self.use_f0 = use_f0
416
- self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0, onnx=onnx)
417
-
418
- if use_f0:
419
- if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
420
- elif vocoder in ["MRF-HiFi-GAN", "MRF HiFi-GAN"]: self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
421
- else: self.dec = GeneratorNSF(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
422
- else: self.dec = Generator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
423
-
424
- self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
425
- self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
426
- self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
427
-
428
- def remove_weight_norm(self):
429
- self.dec.remove_weight_norm()
430
- self.flow.remove_weight_norm()
431
- self.enc_q.remove_weight_norm()
432
-
433
- @torch.jit.ignore
434
- def forward(self, phone, phone_lengths, pitch = None, pitchf = None, y = None, y_lengths = None, ds = None):
435
- g = self.emb_g(ds).unsqueeze(-1)
436
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
437
-
438
- if y is not None:
439
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
440
- z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
441
-
442
- return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
443
- else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
444
-
445
- @torch.jit.export
446
- def infer(self, phone, phone_lengths, pitch = None, nsff0 = None, sid = None, rate = None):
447
- g = self.emb_g(sid).unsqueeze(-1)
448
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
449
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
450
-
451
- if rate is not None:
452
- assert isinstance(rate, torch.Tensor)
453
- head = int(z_p.shape[2] * (1.0 - rate.item()))
454
- z_p = z_p[:, :, head:]
455
- x_mask = x_mask[:, :, head:]
456
- if self.use_f0: nsff0 = nsff0[:, head:]
457
-
458
- if self.use_f0:
459
- z = self.flow(z_p, x_mask, g=g, reverse=True)
460
- o = self.dec(z * x_mask, nsff0, g=g)
461
- else:
462
- z = self.flow(z_p, x_mask, g=g, reverse=True)
463
- o = self.dec(z * x_mask, g=g)
464
-
465
- return o, x_mask, (z, z_p, m_p, logs_p)
466
-
467
- class SynthesizerONNX(Synthesizer):
468
- def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, **kwargs):
469
- super().__init__(spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim, vocoder, checkpointing, True)
470
- self.speaker_map = None
471
-
472
- def remove_weight_norm(self):
473
- self.dec.remove_weight_norm()
474
- self.flow.remove_weight_norm()
475
- self.enc_q.remove_weight_norm()
476
-
477
- def construct_spkmixmap(self, n_speaker):
478
- self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
479
-
480
- for i in range(n_speaker):
481
- self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
482
-
483
- self.speaker_map = self.speaker_map.unsqueeze(0)
484
-
485
- def forward(self, phone, phone_lengths, g=None, rnd=None, pitch=None, nsff0=None, max_len=None):
486
- g = self.emb_g(g).unsqueeze(-1)
487
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
488
- z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
489
-
490
- return self.dec((self.flow(z_p, x_mask, g=g, reverse=True) * x_mask)[:, :, :max_len], nsff0, g=g) if self.use_f0 else self.dec((self.flow(z_p, x_mask, g=g, reverse=True) * x_mask)[:, :, :max_len], g=g)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/architectures/demucs_separator.py DELETED
@@ -1,180 +0,0 @@
1
- import os
2
- import sys
3
- import yaml
4
- import torch
5
-
6
- import numpy as np
7
- from hashlib import sha256
8
-
9
- sys.path.append(os.getcwd())
10
-
11
- from main.configs.config import Config
12
- from main.library.uvr5_separator import spec_utils, common_separator
13
- from main.library.uvr5_separator.demucs import hdemucs, states, apply
14
-
15
- translations = Config().translations
16
- sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))
17
-
18
- DEMUCS_4_SOURCE_MAPPER = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3}
19
-
20
-
21
- class DemucsSeparator(common_separator.CommonSeparator):
22
- def __init__(self, common_config, arch_config):
23
- super().__init__(config=common_config)
24
- self.segment_size = arch_config.get("segment_size", "Default")
25
- self.shifts = arch_config.get("shifts", 2)
26
- self.overlap = arch_config.get("overlap", 0.25)
27
- self.segments_enabled = arch_config.get("segments_enabled", True)
28
- self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
29
- self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
30
- self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
31
- self.audio_file_path = None
32
- self.audio_file_base = None
33
- self.demucs_model_instance = None
34
- self.logger.info(translations["start_demucs"])
35
-
36
- def separate(self, audio_file_path):
37
- self.logger.debug(translations["start_separator"])
38
- source = None
39
- inst_source = {}
40
- self.audio_file_path = audio_file_path
41
- self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
42
- self.logger.debug(translations["prepare_mix"])
43
- mix = self.prepare_mix(self.audio_file_path)
44
- self.logger.debug(translations["demix"].format(shape=mix.shape))
45
- self.logger.debug(translations["cancel_mix"])
46
- self.demucs_model_instance = hdemucs.HDemucs(sources=["drums", "bass", "other", "vocals"])
47
- self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=os.path.dirname(self.model_path))
48
- self.demucs_model_instance = apply.demucs_segments(self.segment_size, self.demucs_model_instance)
49
- self.demucs_model_instance.to(self.torch_device)
50
- self.demucs_model_instance.eval()
51
- self.logger.debug(translations["model_review"])
52
- source = self.demix_demucs(mix)
53
- del self.demucs_model_instance
54
- self.clear_gpu_cache()
55
- self.logger.debug(translations["del_gpu_cache_after_demix"])
56
- output_files = []
57
- self.logger.debug(translations["process_output_file"])
58
-
59
- if isinstance(inst_source, np.ndarray):
60
- self.logger.debug(translations["process_ver"])
61
- inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]] = spec_utils.reshape_sources(inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]])
62
- source = inst_source
63
-
64
- if isinstance(source, np.ndarray):
65
- source_length = len(source)
66
- self.logger.debug(translations["source_length"].format(source_length=source_length))
67
- self.logger.debug(translations["set_map"].format(part=source_length))
68
-
69
- match source_length:
70
- case 2: self.demucs_source_map = {common_separator.CommonSeparator.INST_STEM: 0, common_separator.CommonSeparator.VOCAL_STEM: 1}
71
- case 6: self.demucs_source_map = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3, common_separator.CommonSeparator.GUITAR_STEM: 4, common_separator.CommonSeparator.PIANO_STEM: 5}
72
- case _: self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
73
-
74
- self.logger.debug(translations["process_all_part"])
75
-
76
- for stem_name, stem_value in self.demucs_source_map.items():
77
- if self.output_single_stem is not None:
78
- if stem_name.lower() != self.output_single_stem.lower():
79
- self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
80
- continue
81
-
82
- stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
83
- self.final_process(stem_path, source[stem_value].T, stem_name)
84
- output_files.append(stem_path)
85
-
86
- return output_files
87
-
88
- def demix_demucs(self, mix):
89
- self.logger.debug(translations["starting_demix_demucs"])
90
- processed = {}
91
- mix = torch.tensor(mix, dtype=torch.float32)
92
- ref = mix.mean(0)
93
- mix = (mix - ref.mean()) / ref.std()
94
- mix_infer = mix
95
-
96
- with torch.no_grad():
97
- self.logger.debug(translations["model_infer"])
98
- sources = apply.apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]
99
-
100
- sources = (sources * ref.std() + ref.mean()).cpu().numpy()
101
- sources[[0, 1]] = sources[[1, 0]]
102
-
103
- processed[mix] = sources[:, :, 0:None].copy()
104
- return np.concatenate([s[:, :, 0:None] for s in list(processed.values())], axis=-1)
105
-
106
- class LocalRepo:
107
- def __init__(self, root):
108
- self.root = root
109
- self.scan()
110
-
111
- def scan(self):
112
- self._models, self._checksums = {}, {}
113
- for filename in os.listdir(self.root):
114
- filepath = os.path.join(self.root, filename)
115
- if not os.path.isfile(filepath): continue
116
-
117
- if os.path.splitext(filename)[1] == ".th":
118
- stem = os.path.splitext(filename)[0]
119
-
120
- if "-" in stem:
121
- xp_sig, checksum = stem.split("-", 1)
122
- self._checksums[xp_sig] = checksum
123
- else: xp_sig = stem
124
-
125
- if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
126
- self._models[xp_sig] = filepath
127
-
128
- def has_model(self, sig):
129
- return sig in self._models
130
-
131
- def get_model(self, sig):
132
- try:
133
- file = self._models[sig]
134
- except KeyError:
135
- raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))
136
-
137
- if sig in self._checksums: check_checksum(file, self._checksums[sig])
138
- return states.load_model(file)
139
-
140
- class BagOnlyRepo:
141
- def __init__(self, root, model_repo):
142
- self.root = root
143
- self.model_repo = model_repo
144
- self.scan()
145
-
146
- def scan(self):
147
- self._bags = {}
148
- for filename in os.listdir(self.root):
149
- filepath = os.path.join(self.root, filename)
150
-
151
- if os.path.isfile(filepath) and os.path.splitext(filename)[1] == ".yaml":
152
- stem = os.path.splitext(filename)[0]
153
- self._bags[stem] = filepath
154
-
155
- def get_model(self, name):
156
- try:
157
- yaml_file = self._bags[name]
158
- except KeyError:
159
- raise RuntimeError(translations["name_not_pretrained"].format(name=name))
160
-
161
- with open(yaml_file, 'r') as f:
162
- bag = yaml.safe_load(f)
163
-
164
- return apply.BagOfModels([self.model_repo.get_model(sig) for sig in bag["models"]], bag.get("weights"), bag.get("segment"))
165
-
166
- def check_checksum(path, checksum):
167
- sha = sha256()
168
-
169
- with open(path, "rb") as file:
170
- while 1:
171
- buf = file.read(2**20)
172
- if not buf: break
173
- sha.update(buf)
174
-
175
- actual_checksum = sha.hexdigest()[:len(checksum)]
176
- if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))
177
-
178
- def get_demucs_model(name, repo = None):
179
- model_repo = LocalRepo(repo)
180
- return (model_repo.get_model(name) if model_repo.has_model(name) else BagOnlyRepo(repo, model_repo).get_model(name)).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/architectures/fairseq.py DELETED
@@ -1,1480 +0,0 @@
1
- import re
2
- import sys
3
- import math
4
- import uuid
5
- import torch
6
- import types
7
- import contextlib
8
-
9
- import numpy as np
10
- import torch.nn.functional as F
11
-
12
- from torch import nn
13
- from omegaconf import DictConfig, open_dict
14
-
15
- class Dictionary:
16
- def __init__(self, *args, **kwargs):
17
- pass
18
-
19
- fairseq = types.ModuleType("fairseq")
20
- fairseq_data = types.ModuleType("fairseq.data")
21
- fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")
22
- fairseq_data_dictionary.Dictionary = Dictionary
23
- fairseq.data = fairseq_data
24
- fairseq_data.dictionary = fairseq_data_dictionary
25
-
26
- sys.modules["fairseq"] = fairseq
27
- sys.modules["fairseq.data"] = fairseq_data
28
- sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary
29
-
30
- def load_model(filename):
31
- state = torch.load(filename, map_location="cpu")
32
-
33
- model = HubertModel(HubertConfig(**state['cfg']['model']))
34
- model.load_state_dict(state['model'], strict=False)
35
-
36
- return [model], Model_Config(state["cfg"]), Model_Config(state["cfg"]["task"])
37
-
38
- def softmax(x, dim, onnx_trace = False):
39
- return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32)
40
-
41
- def log_softmax(x, dim, onnx_trace = False):
42
- return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32)
43
-
44
- def eval_str_dict(x, type=dict):
45
- if x is None: return None
46
- if isinstance(x, str): x = eval(x)
47
- return x
48
-
49
- def with_incremental_state(cls):
50
- cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
51
- return cls
52
-
53
- def quant_noise(module, p, block_size):
54
- if p <= 0: return module
55
- assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
56
-
57
- is_conv = module.weight.ndim == 4
58
- if not is_conv: assert (module.weight.size(1) % block_size == 0)
59
- else:
60
- if module.kernel_size == (1, 1): assert (module.in_channels % block_size == 0)
61
- else:
62
- k = module.kernel_size[0] * module.kernel_size[1]
63
- assert k % block_size == 0
64
-
65
- def _forward_pre_hook(mod, input):
66
- if mod.training:
67
- if not is_conv:
68
- weight = mod.weight
69
- in_features = weight.size(1)
70
- out_features = weight.size(0)
71
-
72
- mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
73
- mask.bernoulli_(p)
74
- mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
75
- else:
76
- weight = mod.weight
77
- in_channels = mod.in_channels
78
- out_channels = mod.out_channels
79
-
80
- if mod.kernel_size == (1, 1):
81
- mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
82
- mask.bernoulli_(p)
83
- mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
84
- else:
85
- mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
86
- mask.bernoulli_(p)
87
- mask = (mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
88
-
89
- mask = mask.to(torch.bool)
90
- s = 1 / (1 - p)
91
- mod.weight.data = s * weight.masked_fill(mask, 0)
92
-
93
- module.register_forward_pre_hook(_forward_pre_hook)
94
- return module
95
-
96
- class FairseqDropout(nn.Module):
97
- def __init__(self, p, module_name=None):
98
- super().__init__()
99
- self.p = p
100
- self.module_name = module_name
101
- self.apply_during_inference = False
102
-
103
- def forward(self, x, inplace = False):
104
- return F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x
105
-
106
- def make_generation_fast_(self, name, retain_dropout = False, retain_dropout_modules = None, **kwargs):
107
- if retain_dropout:
108
- if (retain_dropout_modules is None or self.module_name in retain_dropout_modules): self.apply_during_inference = True
109
-
110
- class FairseqIncrementalState(object):
111
- def __init__(self, *args, **kwargs):
112
- super().__init__(*args, **kwargs)
113
- self.init_incremental_state()
114
-
115
- def init_incremental_state(self):
116
- self._incremental_state_id = str(uuid.uuid4())
117
-
118
- def _get_full_incremental_state_key(self, key):
119
- return "{}.{}".format(self._incremental_state_id, key)
120
-
121
- def get_incremental_state(self, incremental_state, key):
122
- full_key = self._get_full_incremental_state_key(key)
123
- if incremental_state is None or full_key not in incremental_state: return None
124
- return incremental_state[full_key]
125
-
126
- def set_incremental_state(self, incremental_state, key, value):
127
- if incremental_state is not None: incremental_state[self._get_full_incremental_state_key(key)] = value
128
- return incremental_state
129
-
130
- class FairseqDecoder(nn.Module):
131
- def __init__(self, dictionary):
132
- super().__init__()
133
- self.dictionary = dictionary
134
- self.onnx_trace = False
135
- self.adaptive_softmax = None
136
-
137
- def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
138
- x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
139
- return self.output_layer(x), extra
140
-
141
- def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
142
- pass
143
-
144
- def output_layer(self, features, **kwargs):
145
- pass
146
-
147
- def get_normalized_probs(self, net_output, log_probs, sample = None):
148
- return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
149
-
150
- def get_normalized_probs_scriptable(self, net_output, log_probs, sample = None):
151
- if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
152
- if sample is not None:
153
- assert "target" in sample
154
- target = sample["target"]
155
- else: target = None
156
-
157
- out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
158
- return out.exp_() if not log_probs else out
159
-
160
- logits = net_output[0]
161
- return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
162
-
163
- def max_positions(self):
164
- return 1e6
165
-
166
- def upgrade_state_dict_named(self, state_dict, name):
167
- return state_dict
168
-
169
- def prepare_for_onnx_export_(self):
170
- self.onnx_trace = True
171
-
172
- @with_incremental_state
173
- class FairseqIncrementalDecoder(FairseqDecoder):
174
- def __init__(self, dictionary):
175
- super().__init__(dictionary)
176
-
177
- def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
178
- pass
179
-
180
- def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
181
- pass
182
-
183
- def reorder_incremental_state(self, incremental_state, new_order):
184
- pass
185
-
186
- def reorder_incremental_state_scripting(self, incremental_state, new_order):
187
- for module in self.modules():
188
- if hasattr(module, "reorder_incremental_state"):
189
- result = module.reorder_incremental_state(incremental_state, new_order)
190
- if result is not None: incremental_state = result
191
-
192
- def set_beam_size(self, beam_size):
193
- if getattr(self, "_beam_size", -1) != beam_size:
194
- seen = set()
195
-
196
- def apply_set_beam_size(module):
197
- if (module != self and hasattr(module, "set_beam_size") and module not in seen):
198
- seen.add(module)
199
- module.set_beam_size(beam_size)
200
-
201
- self.apply(apply_set_beam_size)
202
- self._beam_size = beam_size
203
-
204
- class MultiheadAttention(FairseqIncrementalDecoder):
205
- def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, dictionary=None, q_noise=0.0, qn_block_size=8, xformers_att_config=None, xformers_blocksparse_layout=None, xformers_blocksparse_blocksize=16):
206
- super().__init__(dictionary)
207
- xformers_att_config = eval_str_dict(xformers_att_config)
208
- self.use_xformers = xformers_att_config is not None
209
- if self.use_xformers: raise ImportError
210
- self.embed_dim = embed_dim
211
- self.kdim = kdim if kdim is not None else embed_dim
212
- self.vdim = vdim if vdim is not None else embed_dim
213
- self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
214
- self.num_heads = num_heads
215
- self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
216
- self.head_dim = embed_dim // num_heads
217
- assert (self.head_dim * num_heads == self.embed_dim)
218
- self.scaling = self.head_dim**-0.5
219
- self.self_attention = self_attention
220
- self.encoder_decoder_attention = encoder_decoder_attention
221
- assert not self.self_attention or self.qkv_same_dim
222
- self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
223
- self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
224
- self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
225
- self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
226
- if add_bias_kv:
227
- self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
228
- self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
229
- else: self.bias_k = self.bias_v = None
230
- self.add_zero_attn = add_zero_attn
231
- self.beam_size = 1
232
- self.reset_parameters()
233
- self.onnx_trace = False
234
- self.skip_embed_dim_check = False
235
- self.init_incremental_state()
236
-
237
- def prepare_for_onnx_export_(self):
238
- self.onnx_trace = True
239
-
240
- def reset_parameters(self):
241
- if self.qkv_same_dim:
242
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
243
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
244
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
245
- else:
246
- nn.init.xavier_uniform_(self.k_proj.weight)
247
- nn.init.xavier_uniform_(self.v_proj.weight)
248
- nn.init.xavier_uniform_(self.q_proj.weight)
249
-
250
- nn.init.xavier_uniform_(self.out_proj.weight)
251
-
252
- if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0)
253
- if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k)
254
- if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v)
255
-
256
- def _get_reserve_head_index(self, num_heads_to_keep: int):
257
- k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
258
- for i in range(self.num_heads):
259
- start_idx = i * self.head_dim
260
- end_idx = (i + 1) * self.head_dim
261
- k_proj_heads_norm.append(torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
262
- q_proj_heads_norm.append(torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
263
- v_proj_heads_norm.append(torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
264
-
265
- heads_norm = []
266
- for i in range(self.num_heads):
267
- heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])
268
-
269
- sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
270
- reserve_head_index = []
271
- for i in range(num_heads_to_keep):
272
- reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
273
- return reserve_head_index
274
-
275
- def _adaptive_prune_heads(self, reserve_head_index):
276
- new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []
277
-
278
- for ele in reserve_head_index:
279
- start_idx, end_idx = ele
280
- new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
281
- new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
282
- new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
283
- new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
284
- new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
285
- new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
286
- new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
287
-
288
- new_q_weight = torch.cat(new_q_weight).detach()
289
- new_k_weight = torch.cat(new_k_weight).detach()
290
- new_v_weight = torch.cat(new_v_weight).detach()
291
- new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
292
- new_q_weight.requires_grad = True
293
- new_k_weight.requires_grad = True
294
- new_v_weight.requires_grad = True
295
- new_out_proj_weight.requires_grad = True
296
- new_q_bias = torch.cat(new_q_bias).detach()
297
- new_q_bias.requires_grad = True
298
- new_k_bias = torch.cat(new_k_bias).detach()
299
- new_k_bias.requires_grad = True
300
- new_v_bias = torch.cat(new_v_bias).detach()
301
- new_v_bias.requires_grad = True
302
-
303
- self.q_proj.weight = nn.Parameter(new_q_weight)
304
- self.q_proj.bias = nn.Parameter(new_q_bias)
305
- self.k_proj.weight = nn.Parameter(new_k_weight)
306
- self.k_proj.bias = nn.Parameter(new_k_bias)
307
- self.v_proj.weight = nn.Parameter(new_v_weight)
308
- self.v_proj.bias = nn.Parameter(new_v_bias)
309
- self.out_proj.weight = nn.Parameter(new_out_proj_weight)
310
- self.num_heads = len(reserve_head_index)
311
- self.embed_dim = self.head_dim * self.num_heads
312
- self.q_proj.out_features = self.embed_dim
313
- self.k_proj.out_features = self.embed_dim
314
- self.v_proj.out_features = self.embed_dim
315
-
316
- def _set_skip_embed_dim_check(self):
317
- self.skip_embed_dim_check = True
318
-
319
- def _pad_masks(self, key_padding_mask, attn_mask):
320
- if attn_mask is not None:
321
- shape = attn_mask.size()[:-1] + torch.Size([1])
322
- attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
323
-
324
- if key_padding_mask is not None:
325
- shape = key_padding_mask.size()[:-1] + torch.Size([1])
326
- key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)
327
-
328
- return key_padding_mask, attn_mask
329
-
330
- def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
331
- assert self.bias_k is not None or self.bias_v is not None
332
- key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
333
- return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask
334
-
335
- def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
336
- zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
337
- key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
338
- return torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), key_padding_mask, attn_mask
339
-
340
- def forward(self, query, key, value, key_padding_mask = None, incremental_state = None, need_weights = True, static_kv = False, attn_mask = None, before_softmax = False, need_head_weights = False):
341
- if need_head_weights: need_weights = True
342
- is_tpu = query.device.type == "xla"
343
- tgt_len, bsz, embed_dim = query.size()
344
- src_len = tgt_len
345
-
346
- if not self.skip_embed_dim_check: assert (embed_dim == self.embed_dim)
347
- assert list(query.size()) == [tgt_len, bsz, embed_dim]
348
-
349
- if key is not None:
350
- src_len, key_bsz, _ = key.size()
351
- if not torch.jit.is_scripting():
352
- assert value is not None
353
- assert src_len, key_bsz == value.shape[:2]
354
-
355
- if (not self.onnx_trace and not is_tpu and incremental_state is None and not static_kv and not torch.jit.is_scripting() and not self.skip_embed_dim_check):
356
- assert key is not None and value is not None
357
- return F.multi_head_attention_forward(query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight)
358
-
359
- if incremental_state is not None:
360
- saved_state = self._get_input_buffer(incremental_state)
361
- if saved_state is not None and "prev_key" in saved_state:
362
- if static_kv:
363
- assert self.encoder_decoder_attention and not self.self_attention
364
- key = value = None
365
- else: saved_state = None
366
-
367
- if self.self_attention:
368
- q = self.q_proj(query)
369
- k = self.k_proj(query)
370
- v = self.v_proj(query)
371
- elif self.encoder_decoder_attention:
372
- q = self.q_proj(query)
373
- if key is None:
374
- assert value is None
375
- k = v = None
376
- else:
377
- if self.beam_size > 1 and bsz == key.size(1):
378
- key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
379
- if key_padding_mask is not None: key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
380
- k = self.k_proj(key)
381
- v = self.v_proj(key)
382
- else:
383
- assert key is not None and value is not None
384
- q = self.q_proj(query)
385
- k = self.k_proj(key)
386
- v = self.v_proj(value)
387
-
388
- q *= self.scaling
389
-
390
- if self.bias_k is not None:
391
- assert self.bias_v is not None
392
- k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz)
393
-
394
- q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
395
- kv_bsz = bsz
396
-
397
- if k is not None:
398
- kv_bsz = k.size(1)
399
- k = (k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
400
-
401
- if v is not None: v = (v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
402
-
403
- if saved_state is not None:
404
- if "prev_key" in saved_state:
405
- _prev_key = saved_state["prev_key"]
406
- assert _prev_key is not None
407
-
408
- kv_bsz = _prev_key.size(0)
409
- prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
410
-
411
- if static_kv: k = prev_key
412
- else:
413
- assert k is not None
414
- k = torch.cat([prev_key, k], dim=1)
415
- src_len = k.size(1)
416
-
417
- if "prev_value" in saved_state:
418
- _prev_value = saved_state["prev_value"]
419
- assert _prev_value is not None or kv_bsz == _prev_value.size(0)
420
- prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)
421
-
422
- if static_kv: v = prev_value
423
- else:
424
- assert v is not None
425
- v = torch.cat([prev_value, v], dim=1)
426
-
427
- prev_key_padding_mask = None
428
- if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"]
429
-
430
- assert k is not None and v is not None
431
- key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv)
432
-
433
- saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
434
- saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
435
- saved_state["prev_key_padding_mask"] = key_padding_mask
436
-
437
- assert incremental_state is not None
438
- incremental_state = self._set_input_buffer(incremental_state, saved_state)
439
-
440
- assert k is not None
441
- assert k.size(1) == src_len
442
-
443
- if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None
444
-
445
- if key_padding_mask is not None:
446
- assert key_padding_mask.size(0) == kv_bsz
447
- assert key_padding_mask.size(1) == src_len
448
-
449
- if self.add_zero_attn:
450
- assert v is not None
451
- src_len += 1
452
- k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
453
-
454
- if self.encoder_decoder_attention and bsz != kv_bsz:
455
- attn_weights = torch.einsum("bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]))
456
- attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
457
- else: attn_weights = torch.bmm(q, k.transpose(1, 2))
458
-
459
- attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
460
- assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
461
-
462
- if attn_mask is not None:
463
- attn_mask = attn_mask.unsqueeze(0)
464
- if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
465
- attn_weights += attn_mask
466
-
467
- if key_padding_mask is not None:
468
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
469
- attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), float("-inf")) if not is_tpu else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
470
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
471
-
472
- if before_softmax: return attn_weights, v
473
-
474
- attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
475
- attn_weights = attn_weights_float.type_as(attn_weights)
476
- attn_probs = self.dropout_module(attn_weights)
477
-
478
- assert v is not None
479
- attn = None
480
-
481
- if self.encoder_decoder_attention and bsz != kv_bsz:
482
- attn = torch.einsum("bxhts,bhsd->bxhtd", attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), v.view((kv_bsz, self.num_heads) + v.size()[1:]))
483
- attn = attn.reshape((-1,) + attn.size()[-2:])
484
- else: attn = torch.bmm(attn_probs, v)
485
- assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
486
-
487
- if self.onnx_trace and attn.size(1) == 1: attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
488
- else: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
489
-
490
- attn = self.out_proj(attn)
491
- attn_weights = None
492
-
493
- if need_weights:
494
- attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
495
- if not need_head_weights: attn_weights = attn_weights.mean(dim=0)
496
-
497
- return attn, attn_weights
498
-
499
- @staticmethod
500
- def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
501
- if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask
502
- elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
503
- elif prev_key_padding_mask is not None:
504
- if src_len > prev_key_padding_mask.size(1):
505
- filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
506
- new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
507
- else: new_key_padding_mask = prev_key_padding_mask.float()
508
- elif key_padding_mask is not None:
509
- if src_len > key_padding_mask.size(1):
510
- filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
511
- new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
512
- else: new_key_padding_mask = key_padding_mask.float()
513
- else: new_key_padding_mask = prev_key_padding_mask
514
- return new_key_padding_mask
515
-
516
- @torch.jit.export
517
- def reorder_incremental_state(self, incremental_state, new_order):
518
- input_buffer = self._get_input_buffer(incremental_state)
519
- if input_buffer is not None:
520
- for k in input_buffer.keys():
521
- input_buffer_k = input_buffer[k]
522
- if input_buffer_k is not None:
523
- if self.encoder_decoder_attention:
524
- if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state
525
- elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
526
- else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
527
- else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
528
- incremental_state = self._set_input_buffer(incremental_state, input_buffer)
529
- return incremental_state
530
-
531
- def set_beam_size(self, beam_size):
532
- self.beam_size = beam_size
533
-
534
- def _get_input_buffer(self, incremental_state):
535
- result = self.get_incremental_state(incremental_state, "attn_state")
536
- if result is not None: return result
537
- else: return {}
538
-
539
- def _set_input_buffer(self, incremental_state, buffer):
540
- return self.set_incremental_state(incremental_state, "attn_state", buffer)
541
-
542
- def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
543
- return attn_weights
544
-
545
- def upgrade_state_dict_named(self, state_dict, name):
546
- prefix = name + "." if name != "" else ""
547
- items_to_add = {}
548
- keys_to_remove = []
549
- for k in state_dict.keys():
550
- if k.endswith(prefix + "in_proj_weight"):
551
- dim = int(state_dict[k].shape[0] / 3)
552
- items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
553
- items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
554
- items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
555
- keys_to_remove.append(k)
556
- k_bias = prefix + "in_proj_bias"
557
- if k_bias in state_dict.keys():
558
- dim = int(state_dict[k].shape[0] / 3)
559
- items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
560
- items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
561
- items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
562
- keys_to_remove.append(prefix + "in_proj_bias")
563
-
564
- for k in keys_to_remove:
565
- del state_dict[k]
566
-
567
- for key, value in items_to_add.items():
568
- state_dict[key] = value
569
-
570
- def init_bert_params(module):
571
- def normal_(data):
572
- data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
573
-
574
- if isinstance(module, nn.Linear):
575
- normal_(module.weight.data)
576
- if module.bias is not None: module.bias.data.zero_()
577
- if isinstance(module, nn.Embedding):
578
- normal_(module.weight.data)
579
- if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_()
580
- if isinstance(module, MultiheadAttention):
581
- normal_(module.q_proj.weight.data)
582
- normal_(module.k_proj.weight.data)
583
- normal_(module.v_proj.weight.data)
584
-
585
- def make_conv_pos(e, k, g):
586
- pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)
587
- dropout = 0
588
-
589
- nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e)))
590
- nn.init.constant_(pos_conv.bias, 0)
591
-
592
- return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU())
593
-
594
- def is_xla_tensor(tensor):
595
- return torch.is_tensor(tensor) and tensor.device.type == "xla"
596
-
597
- def index_put(tensor, indices, value):
598
- if is_xla_tensor(tensor):
599
- for _ in range(indices.dim(), tensor.dim()):
600
- indices = indices.unsqueeze(-1)
601
-
602
- if indices.size(-1) < tensor.size(-1): indices = indices.expand_as(tensor)
603
- tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
604
- else: tensor[indices] = value
605
-
606
- return tensor
607
-
608
- def pad_to_multiple(x, multiple, dim=-1, value=0):
609
- if x is None: return None, 0
610
- tsz = x.size(dim)
611
- m = tsz / multiple
612
- remainder = math.ceil(m) * multiple - tsz
613
- if m.is_integer(): return x, 0
614
- return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder
615
-
616
- def compute_mask_indices(shape, padding_mask, mask_prob, mask_length, mask_type = "static", mask_other = 0.0, min_masks = 0, no_overlap = False, min_space = 0, require_same_masks = True, mask_dropout = 0.0, add_masks = False, seed = None, epoch = None, indices = None, idc_select_ver = 1, num_mask_ver = 2):
617
- bsz, all_sz = shape
618
- mask = np.full((bsz, all_sz), False)
619
-
620
- if num_mask_ver == 1: all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
621
- mask_idcs = []
622
-
623
- for i in range(bsz):
624
- seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
625
- rng = np.random.default_rng(seed_i)
626
-
627
- if padding_mask is not None:
628
- sz = all_sz - padding_mask[i].long().sum().item()
629
- assert sz >= 0, sz
630
- else: sz = all_sz
631
-
632
- if num_mask_ver == 1: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
633
- elif num_mask_ver == 2: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
634
- else: raise ValueError
635
-
636
- if mask_type == "static": lengths = np.full(num_mask, mask_length)
637
- elif mask_type == "uniform": lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
638
- elif mask_type == "normal": lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
639
- elif mask_type == "poisson": lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
640
- else: raise Exception
641
-
642
- if sum(lengths) == 0:
643
- if mask_type == "static": raise ValueError
644
- else: lengths = [min(mask_length, sz - 1)]
645
-
646
- if no_overlap:
647
- mask_idc = []
648
-
649
- def arrange(s, e, length, keep_length):
650
- span_start = rng.randint(s, e - length)
651
- mask_idc.extend(span_start + i for i in range(length))
652
- new_parts = []
653
-
654
- if span_start - s - min_space >= keep_length: new_parts.append((s, span_start - min_space + 1))
655
- if e - span_start - length - min_space > keep_length: new_parts.append((span_start + length + min_space, e))
656
-
657
- return new_parts
658
-
659
- parts = [(0, sz)]
660
- min_length = min(lengths)
661
-
662
- for length in sorted(lengths, reverse=True):
663
- lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
664
- l_sum = np.sum(lens)
665
- if l_sum == 0: break
666
- s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
667
- parts.extend(arrange(s, e, length, min_length))
668
- mask_idc = np.asarray(mask_idc)
669
- else:
670
- if idc_select_ver == 1:
671
- min_len = min(lengths)
672
- if sz - min_len <= num_mask: min_len = sz - num_mask - 1
673
- mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
674
- elif idc_select_ver == 2: mask_idc = rng.choice(sz, num_mask, replace=False)
675
- else: raise ValueError
676
-
677
- mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
678
-
679
- mask_idc = np.unique(mask_idc[mask_idc < sz])
680
- if len(mask_idc) >= sz: raise ValueError
681
- mask_idcs.append(mask_idc)
682
-
683
- target_len = None
684
- if require_same_masks: target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])
685
-
686
- for i, mask_idc in enumerate(mask_idcs):
687
- if target_len is not None and len(mask_idc) > target_len: mask_idc = rng.choice(mask_idc, target_len, replace=False)
688
- mask[i, mask_idc] = True
689
-
690
- if target_len is not None and len(mask_idc) < target_len:
691
- to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
692
- mask[i, to_mask] = True
693
-
694
- if mask_dropout > 0:
695
- masked = np.flatnonzero(mask[i])
696
- mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False
697
-
698
- return mask
699
-
700
- def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
701
- return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
702
-
703
- def prune_state_dict(state_dict, model_cfg):
704
- arch = None
705
- if model_cfg is not None: arch = (model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None))
706
-
707
- if not model_cfg or arch is None or arch == "ptt_transformer": return state_dict
708
-
709
- encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
710
- decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
711
-
712
- if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict
713
-
714
- def create_pruning_pass(layers_to_keep, layer_name):
715
- keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
716
- mapping_dict = {}
717
- for i in range(len(keep_layers)):
718
- mapping_dict[str(keep_layers[i])] = str(i)
719
-
720
- return {"substitution_regex": re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)), "mapping_dict": mapping_dict}
721
-
722
- pruning_passes = []
723
- new_state_dict = {}
724
-
725
- if encoder_layers_to_keep: pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
726
- if decoder_layers_to_keep: pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
727
-
728
- for layer_name in state_dict.keys():
729
- match = re.search(r"\.layers\.(\d+)\.", layer_name)
730
- if not match:
731
- new_state_dict[layer_name] = state_dict[layer_name]
732
- continue
733
-
734
- original_layer_number = match.group(1)
735
- for pruning_pass in pruning_passes:
736
- if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
737
- substitution_match = pruning_pass["substitution_regex"].search(layer_name)
738
- new_state_dict[(layer_name[: substitution_match.start(1)] + pruning_pass["mapping_dict"][original_layer_number] + layer_name[substitution_match.end(1) :])] = state_dict[layer_name]
739
-
740
- with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
741
- if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None
742
- if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None
743
-
744
- return new_state_dict
745
-
746
- def relu_squared(x):
747
- return F.relu(x).pow(2)
748
-
749
- def get_activation_fn(activation):
750
- def gelu(x):
751
- return nn.functional.gelu(x.float()).type_as(x)
752
-
753
- def gelu_accurate(x):
754
- if not hasattr(gelu_accurate, "_a"):
755
- gelu_accurate._a = math.sqrt(2 / math.pi)
756
- return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))))
757
-
758
- if activation == "relu": return F.relu
759
- elif activation == "relu_squared": return relu_squared
760
- elif activation == "gelu": return gelu
761
- elif activation == "gelu_fast": return gelu_accurate
762
- elif activation == "gelu_accurate": return gelu_accurate
763
- elif activation == "tanh": return torch.tanh
764
- elif activation == "linear": return lambda x: x
765
- elif activation == "swish": return nn.SiLU
766
- else: raise RuntimeError
767
-
768
- class SamePad(nn.Module):
769
- def __init__(self, kernel_size, causal=False):
770
- super().__init__()
771
- if causal: self.remove = kernel_size - 1
772
- else: self.remove = 1 if kernel_size % 2 == 0 else 0
773
-
774
- def forward(self, x):
775
- if self.remove > 0: x = x[:, :, : -self.remove]
776
- return x
777
-
778
- class TransformerSentenceEncoderLayer(nn.Module):
779
- def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False):
780
- super().__init__()
781
- self.embedding_dim = embedding_dim
782
- self.dropout = dropout
783
- self.activation_dropout = activation_dropout
784
- self.activation_fn = get_activation_fn(activation_fn)
785
- self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
786
- self.dropout1 = nn.Dropout(dropout)
787
- self.dropout2 = nn.Dropout(self.activation_dropout)
788
- self.dropout3 = nn.Dropout(dropout)
789
- self.layer_norm_first = layer_norm_first
790
- self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
791
- self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
792
- self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
793
- self.final_layer_norm = LayerNorm(self.embedding_dim)
794
-
795
- def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
796
- residual = x
797
-
798
- if self.layer_norm_first:
799
- x = self.self_attn_layer_norm(x)
800
- x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False)
801
- x = residual + self.dropout1(x)
802
- residual = x
803
- x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
804
- layer_result = x
805
- x = residual + self.dropout3(x)
806
- else:
807
- x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
808
- x = self.self_attn_layer_norm(residual + self.dropout1(x))
809
- residual = x
810
- x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
811
- layer_result = x
812
- x = self.final_layer_norm(residual + self.dropout3(x))
813
-
814
- return x, (attn, layer_result)
815
-
816
- class AdapterFast(nn.Module):
817
- def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
818
- super().__init__()
819
- self.adapter_num = adapter_num
820
- self.input_dim = input_dim
821
- self.hidden_dim = hidden_dim
822
- self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
823
- self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
824
- self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
825
- self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
826
- self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
827
- self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
828
- self.act_fn = nn.Identity()
829
- if act_fn == "relu": self.act_fn = nn.ReLU()
830
- elif act_fn == "gelu": self.act_fn = nn.GELU()
831
- elif act_fn == "selu": self.act_fn = nn.SELU()
832
- else: raise ValueError
833
-
834
- self.input_dim = input_dim
835
- self.reset_parameters()
836
-
837
- def reset_parameters(self):
838
- for ii in range(self.adapter_num):
839
- nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
840
- nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
841
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
842
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
843
- nn.init.uniform_(self.b_a[ii], -bound, bound)
844
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
845
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
846
- nn.init.uniform_(self.b_b[ii], -bound, bound)
847
-
848
- nn.init.ones_(self.ln_W)
849
- nn.init.zeros_(self.ln_b)
850
-
851
- def forward(self, x, adapter_id):
852
- ii = adapter_id
853
- return F.linear(self.act_fn(F.linear(F.layer_norm(x, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])), self.W_b[ii], self.b_b[ii])
854
-
855
- def extra_repr(self):
856
- return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
857
-
858
- class FeedForwardModule(nn.Module):
859
- def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
860
- super(FeedForwardModule, self).__init__()
861
- self.layer_norm = LayerNorm(input_feat)
862
- self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
863
- self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
864
- self.dropout1 = nn.Dropout(dropout1)
865
- self.dropout2 = nn.Dropout(dropout2)
866
- self.activation = get_activation_fn(activation_fn)(hidden_units)
867
-
868
- def forward(self, x):
869
- return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x))))))
870
-
871
- class ConvolutionModule(nn.Module):
872
- def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
873
- super(ConvolutionModule, self).__init__()
874
- assert (depthwise_kernel_size - 1) % 2 == 0
875
- self.layer_norm = LayerNorm(embed_dim, export=export)
876
- self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
877
- self.glu = nn.GLU(dim=1)
878
- self.depthwise_conv = nn.Conv1d(channels, channels, depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2, groups=channels, bias=bias)
879
- self.batch_norm = nn.BatchNorm1d(channels)
880
- self.activation = get_activation_fn(activation_fn)(channels)
881
- self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
882
- self.dropout = nn.Dropout(dropout)
883
-
884
- def forward(self, x):
885
- return self.dropout(self.pointwise_conv2(self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))))).transpose(1, 2)
886
-
887
- def rotate_half(x):
888
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
889
- return torch.cat((-x2, x1), dim=x1.ndim - 1)
890
-
891
- def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
892
- cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...])
893
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
894
-
895
- class RotaryPositionalEmbedding(nn.Module):
896
- def __init__(self, dim, base=10000, precision=torch.half):
897
- super().__init__()
898
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
899
- self.register_buffer("inv_freq", inv_freq)
900
- self.seq_len_cached = 0
901
- self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
902
- self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
903
- self.precision = precision
904
-
905
- def forward(self, x, seq_len = 0):
906
- if seq_len > self.seq_len_cached:
907
- self.seq_len_cached = seq_len
908
- freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq)
909
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
910
- self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
911
- self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
912
- return self.cos_cached, self.sin_cached
913
-
914
- class ESPNETMultiHeadedAttention(nn.Module):
915
- def __init__(self, n_feat, n_head, dropout):
916
- super(ESPNETMultiHeadedAttention, self).__init__()
917
- assert n_feat % n_head == 0
918
- self.d_k = n_feat // n_head
919
- self.h = n_head
920
- self.linear_q = nn.Linear(n_feat, n_feat)
921
- self.linear_k = nn.Linear(n_feat, n_feat)
922
- self.linear_v = nn.Linear(n_feat, n_feat)
923
- self.linear_out = nn.Linear(n_feat, n_feat)
924
- self.attn = None
925
- self.dropout = nn.Dropout(p=dropout)
926
-
927
- def forward_qkv(self, query, key, value, **kwargs):
928
- n_batch = query.size(0)
929
- return self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
930
-
931
- def forward_attention(self, value, scores, mask):
932
- n_batch = value.size(0)
933
-
934
- if mask is not None:
935
- scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
936
- self.attn = torch.softmax(scores, dim=-1)
937
- else: self.attn = torch.softmax(scores, dim=-1)
938
-
939
- return self.linear_out((torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)))
940
-
941
- def forward(self, query, key, value, key_padding_mask=None, **kwargs):
942
- q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
943
- return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
944
-
945
- class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
946
- def __init__(self, n_feat, n_head, dropout, zero_triu=False):
947
- super().__init__(n_feat, n_head, dropout)
948
- self.zero_triu = zero_triu
949
- self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
950
- self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
951
- self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
952
- nn.init.xavier_uniform_(self.pos_bias_u)
953
- nn.init.xavier_uniform_(self.pos_bias_v)
954
-
955
- def rel_shift(self, x):
956
- x = torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1).view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
957
- if self.zero_triu: x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
958
- return x
959
-
960
- def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
961
- pos_emb = pos_emb.transpose(0, 1)
962
- q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
963
- q = q.transpose(1, 2)
964
-
965
- return self.forward_attention(v, (torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + self.rel_shift(torch.matmul((q + self.pos_bias_v).transpose(1, 2), self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1)))) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
966
-
967
- class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
968
- def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
969
- super().__init__(n_feat, n_head, dropout)
970
- precision = torch.float
971
- self.rotary_ndims = self.d_k
972
- if precision == "fp16": precision = torch.half
973
- self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)
974
-
975
- def forward(self, query, key, value, key_padding_mask=None, **kwargs):
976
- T, B, C = value.size()
977
- query = query.view(T, B, self.h, self.d_k)
978
- key = key.view(T, B, self.h, self.d_k)
979
- value = value.view(T, B, self.h, self.d_k)
980
-
981
- cos, sin = self.rotary_emb(value, seq_len=T)
982
- query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
983
-
984
- query = query.view(T, B, self.h * self.d_k)
985
- key = key.view(T, B, self.h * self.d_k)
986
- value = value.view(T, B, self.h * self.d_k)
987
-
988
- q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
989
- return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
990
-
991
- class ConformerEncoderLayer(nn.Module):
992
- def __init__(self, embed_dim, ffn_embed_dim, attention_heads, dropout, use_fp16, depthwise_conv_kernel_size=31, activation_fn="swish", attn_type=None, pos_enc_type="abs"):
993
- self.pos_enc_type = pos_enc_type
994
- super(ConformerEncoderLayer, self).__init__()
995
- self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
996
- self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
997
- self.self_attn_dropout = nn.Dropout(dropout)
998
-
999
- if attn_type == "espnet":
1000
- if self.pos_enc_type == "rel_pos": self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
1001
- elif self.pos_enc_type == "rope": self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
1002
- elif self.pos_enc_type == "abs": self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
1003
- else: raise Exception
1004
- else: self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)
1005
-
1006
- self.conv_module = ConvolutionModule(embed_dim=embed_dim, channels=embed_dim, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, activation_fn=activation_fn)
1007
- self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
1008
- self.final_layer_norm = LayerNorm(embed_dim, export=False)
1009
-
1010
- def forward(self, x, encoder_padding_mask, position_emb = None):
1011
- residual = x
1012
- x = self.ffn1(x) * 0.5 + residual
1013
- residual = x
1014
- x = self.self_attn_layer_norm(x)
1015
-
1016
- if self.pos_enc_type == "rel_pos": x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, pos_emb=position_emb, need_weights=False)
1017
- else: x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)
1018
-
1019
- x = self.self_attn_dropout(x)
1020
- x = x + residual
1021
- residual = x
1022
- x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
1023
- residual = x
1024
- x = self.ffn2(x)
1025
- layer_result = x
1026
- x = self.final_layer_norm(x * 0.5 + residual)
1027
-
1028
- return x, (attn, layer_result)
1029
-
1030
- class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
1031
- def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
1032
- return super().forward(x, self_attn_padding_mask, position_emb)
1033
-
1034
- class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
1035
- def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu"):
1036
- super().__init__(embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first)
1037
- self.adapter_num = adapter_num
1038
- self.adapter_dim = adapter_dim
1039
- self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
1040
-
1041
- def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
1042
-
1043
- x, (attn, layer_result) = super().forward(x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args)
1044
- assert corpus_key is not None
1045
- assert len(set(corpus_key)) == 1
1046
-
1047
- return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result)
1048
-
1049
- class TransposeLast(nn.Module):
1050
- def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
1051
- super().__init__()
1052
- self.deconstruct_idx = deconstruct_idx
1053
- self.tranpose_dim = tranpose_dim
1054
-
1055
- def forward(self, x):
1056
- if self.deconstruct_idx is not None: x = x[self.deconstruct_idx]
1057
- return x.transpose(self.tranpose_dim, -1)
1058
-
1059
- class TransformerEncoder(nn.Module):
1060
- def build_encoder_layer(self, args, **kwargs):
1061
- if args.layer_type == "transformer": layer = TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)
1062
- elif args.layer_type == "conformer": layer = ConformerWav2Vec2EncoderLayer(embed_dim=self.embedding_dim, ffn_embed_dim=args.encoder_ffn_embed_dim, attention_heads=args.encoder_attention_heads, dropout=args.dropout, depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, activation_fn="swish", attn_type=args.attn_type, use_fp16=args.fp16, pos_enc_type="abs")
1063
- elif args.layer_type == "trf_adp":
1064
- use_adp = False
1065
- if args.adp_trf_idx == "all": use_adp = True
1066
- else:
1067
- if kwargs.get("layer_idx", None) in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): use_adp = True
1068
-
1069
- layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, adapter_num=args.adp_num, adapter_dim=args.adp_dim, adapter_act_fn=args.adp_act_fn) if use_adp else TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first,)
1070
-
1071
- return layer
1072
-
1073
- def __init__(self, args):
1074
- super().__init__()
1075
- self.dropout = args.dropout
1076
- self.embedding_dim = args.encoder_embed_dim
1077
- self.required_seq_len_multiple = args.required_seq_len_multiple
1078
- pos_conv_depth = getattr(args, "pos_conv_depth", 1)
1079
-
1080
- if pos_conv_depth > 1:
1081
- num_layers = args.pos_conv_depth
1082
- k = max(3, args.conv_pos // num_layers)
1083
-
1084
- def make_conv_block(e, k, g, l):
1085
- return nn.Sequential(*[nn.Sequential(nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), SamePad(k), TransposeLast(), LayerNorm(e, elementwise_affine=False), TransposeLast(), nn.GELU()) for _ in range(l)])
1086
-
1087
- self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
1088
- else: self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)
1089
-
1090
- self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
1091
- self.layer_norm_first = args.layer_norm_first
1092
- self.layer_norm = LayerNorm(self.embedding_dim)
1093
- self.layerdrop = args.encoder_layerdrop
1094
- self.apply(init_bert_params)
1095
-
1096
- def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
1097
- x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)
1098
-
1099
- if self.layer_norm_first and layer is None: x = self.layer_norm(x)
1100
- return x, layer_results
1101
-
1102
- def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
1103
- if padding_mask is not None: x = index_put(x, padding_mask, 0)
1104
- x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
1105
-
1106
- if not self.layer_norm_first: x = self.layer_norm(x)
1107
- x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
1108
-
1109
- if pad_length > 0 and padding_mask is None:
1110
- padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
1111
- padding_mask[:, -pad_length:] = True
1112
- else: padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)
1113
-
1114
- x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)
1115
- layer_results = []
1116
- r = None
1117
-
1118
- for i, layer in enumerate(self.layers):
1119
- dropout_probability = np.random.random() if self.layerdrop > 0 else 1
1120
- if not self.training or (dropout_probability > self.layerdrop):
1121
- layer_check = layer
1122
-
1123
- if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
1124
- else: x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)
1125
-
1126
- if i >= min_layer: layer_results.append((x, z, lr))
1127
- if i == tgt_layer:
1128
- r = x
1129
- break
1130
-
1131
- if r is not None: x = r
1132
- x = x.transpose(0, 1)
1133
-
1134
- if pad_length > 0:
1135
- x = x[:, :-pad_length]
1136
- def undo_pad(a, b, c):
1137
- return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])
1138
-
1139
- layer_results = [undo_pad(*u) for u in layer_results]
1140
-
1141
- return x, layer_results
1142
-
1143
- def max_positions(self):
1144
- return self.args.max_positions
1145
-
1146
- def upgrade_state_dict_named(self, state_dict, name):
1147
- return state_dict
1148
-
1149
- class Fp32GroupNorm(nn.GroupNorm):
1150
- def __init__(self, *args, **kwargs):
1151
- super().__init__(*args, **kwargs)
1152
-
1153
- def forward(self, input):
1154
- output = F.group_norm(input.float(), self.num_groups, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
1155
- return output.type_as(input)
1156
-
1157
- class Fp32LayerNorm(nn.LayerNorm):
1158
- def __init__(self, *args, **kwargs):
1159
- super().__init__(*args, **kwargs)
1160
-
1161
- def forward(self, input):
1162
- output = F.layer_norm(input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
1163
- return output.type_as(input)
1164
-
1165
- class ConvFeatureExtractionModel(nn.Module):
1166
- def __init__(self, conv_layers, dropout = 0.0, mode = "default", conv_bias = False):
1167
- super().__init__()
1168
- assert mode in {"default", "layer_norm"}
1169
-
1170
- def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
1171
- def make_conv():
1172
- conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
1173
- nn.init.kaiming_normal_(conv.weight)
1174
- return conv
1175
-
1176
- assert (is_layer_norm and is_group_norm) == False
1177
-
1178
- if is_layer_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), nn.GELU())
1179
- elif is_group_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
1180
- else: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
1181
-
1182
- in_d = 1
1183
- self.conv_layers = nn.ModuleList()
1184
- for i, cl in enumerate(conv_layers):
1185
- assert len(cl) == 3
1186
- (dim, k, stride) = cl
1187
- self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=mode == "layer_norm", is_group_norm=mode == "default" and i == 0, conv_bias=conv_bias))
1188
- in_d = dim
1189
-
1190
- def forward(self, x):
1191
- x = x.unsqueeze(1)
1192
- for conv in self.conv_layers:
1193
- x = conv(x)
1194
-
1195
- return x
1196
-
1197
- class GradMultiply(torch.autograd.Function):
1198
- @staticmethod
1199
- def forward(ctx, x, scale):
1200
- ctx.scale = scale
1201
- res = x.new(x)
1202
- return res
1203
-
1204
- @staticmethod
1205
- def backward(ctx, grad):
1206
- return grad * ctx.scale, None
1207
-
1208
- class BaseFairseqModel(nn.Module):
1209
- def __init__(self):
1210
- super().__init__()
1211
- self._is_generation_fast = False
1212
-
1213
- def get_targets(self, sample, net_output):
1214
- return sample["target"]
1215
-
1216
- def extract_features(self, *args, **kwargs):
1217
- return self(*args, **kwargs)
1218
-
1219
- def load_state_dict(self, state_dict, strict=True, model_cfg = None, args = None):
1220
- self.upgrade_state_dict(state_dict)
1221
- new_state_dict = prune_state_dict(state_dict, model_cfg)
1222
- return super().load_state_dict(new_state_dict, strict)
1223
-
1224
- def upgrade_state_dict(self, state_dict):
1225
- self.upgrade_state_dict_named(state_dict, "")
1226
-
1227
- def upgrade_state_dict_named(self, state_dict, name):
1228
- assert state_dict is not None
1229
-
1230
- def do_upgrade(m, prefix):
1231
- if len(prefix) > 0: prefix += "."
1232
-
1233
- for n, c in m.named_children():
1234
- name = prefix + n
1235
- if hasattr(c, "upgrade_state_dict_named"): c.upgrade_state_dict_named(state_dict, name)
1236
- elif hasattr(c, "upgrade_state_dict"): c.upgrade_state_dict(state_dict)
1237
- do_upgrade(c, name)
1238
-
1239
- do_upgrade(self, name)
1240
-
1241
- def make_generation_fast_(self, **kwargs):
1242
- if self._is_generation_fast: return
1243
- self._is_generation_fast = True
1244
-
1245
- def apply_remove_weight_norm(module):
1246
- try:
1247
- nn.utils.remove_weight_norm(module)
1248
- except (AttributeError, ValueError):
1249
- return
1250
-
1251
- self.apply(apply_remove_weight_norm)
1252
-
1253
- def apply_make_generation_fast_(module, prefix):
1254
- if len(prefix) > 0: prefix += "."
1255
-
1256
- base_func = BaseFairseqModel.make_generation_fast_
1257
- for n, m in module.named_modules():
1258
- if (m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func): m.make_generation_fast_(name=prefix + n, **kwargs)
1259
-
1260
- apply_make_generation_fast_(self, "")
1261
- self.eval()
1262
-
1263
- class HubertConfig:
1264
- def __init__(self, _name, label_rate, encoder_layers_1, logit_temp_ctr, num_negatives, cross_sample_negatives, ctr_layers, extractor_mode = "default", encoder_layers = 12, encoder_embed_dim = 768, encoder_ffn_embed_dim = 3072, encoder_attention_heads = 12, activation_fn = "gelu", layer_type = "transformer", dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.0, encoder_layerdrop = 0.0, dropout_input = 0.0, dropout_features = 0.0, final_dim = 0, untie_final_proj = False, layer_norm_first = False, conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", conv_bias = False, logit_temp = 0.1, target_glu = False, feature_grad_mult = 1.0, mask_length = 10, mask_prob = 0.65, mask_selection = "static", mask_other = 0.0, no_mask_overlap = False, mask_min_space = 1, mask_channel_length = 10, mask_channel_prob = 0.0, mask_channel_selection = "static", mask_channel_other = 0.0, no_mask_channel_overlap = False, mask_channel_min_space = 1, conv_pos = 128, conv_pos_groups = 16, conv_pos_batch_norm = False, latent_temp = (2, 0.5, 0.999995), skip_masked = False, skip_nomask = False, checkpoint_activations = False, required_seq_len_multiple = 2, depthwise_conv_kernel_size = 31, attn_type = "", pos_enc_type = "abs", fp16 = False):
1265
- self._name = _name
1266
- self.label_rate = label_rate
1267
- self.encoder_layers_1 = encoder_layers_1
1268
- self.logit_temp_ctr = logit_temp_ctr
1269
- self.num_negatives = num_negatives
1270
- self.cross_sample_negatives = cross_sample_negatives
1271
- self.ctr_layers = ctr_layers
1272
- self.extractor_mode = extractor_mode
1273
- self.encoder_layers = encoder_layers
1274
- self.encoder_embed_dim = encoder_embed_dim
1275
- self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
1276
- self.encoder_attention_heads = encoder_attention_heads
1277
- self.activation_fn = activation_fn
1278
- self.layer_type = layer_type
1279
- self.dropout = dropout
1280
- self.attention_dropout = attention_dropout
1281
- self.activation_dropout = activation_dropout
1282
- self.encoder_layerdrop = encoder_layerdrop
1283
- self.dropout_input = encoder_layerdrop
1284
- self.dropout_features = dropout_features
1285
- self.final_dim = final_dim
1286
- self.untie_final_proj = untie_final_proj
1287
- self.layer_norm_first = layer_norm_first
1288
- self.conv_feature_layers = conv_feature_layers
1289
- self.conv_bias = conv_bias
1290
- self.logit_temp = logit_temp
1291
- self.target_glu = target_glu
1292
- self.feature_grad_mult = feature_grad_mult
1293
- self.mask_length = mask_length
1294
- self.mask_prob = mask_prob
1295
- self.mask_selection = mask_selection
1296
- self.mask_other = mask_other
1297
- self.no_mask_overlap = no_mask_overlap
1298
- self.mask_min_space = mask_min_space
1299
- self.mask_channel_length = mask_channel_length
1300
- self.mask_channel_prob = mask_channel_prob
1301
- self.mask_channel_selection = mask_channel_selection
1302
- self.mask_channel_other = mask_channel_other
1303
- self.no_mask_channel_overlap = no_mask_channel_overlap
1304
- self.mask_channel_min_space = mask_channel_min_space
1305
- self.conv_pos = conv_pos
1306
- self.conv_pos_groups = conv_pos_groups
1307
- self.conv_pos_batch_norm = conv_pos_batch_norm
1308
- self.latent_temp = latent_temp
1309
- self.skip_masked = skip_masked
1310
- self.skip_nomask = skip_nomask
1311
- self.checkpoint_activations = checkpoint_activations
1312
- self.required_seq_len_multiple = required_seq_len_multiple
1313
- self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
1314
- self.attn_type = attn_type
1315
- self.pos_enc_type = pos_enc_type
1316
- self.fp16 = fp16
1317
-
1318
- class Model_Config(dict):
1319
- def __getattr__(*args):
1320
- val = dict.get(*args)
1321
- return Model_Config(val) if type(val) is dict else val
1322
-
1323
- __setattr__ = dict.__setitem__
1324
- __delattr__ = dict.__delitem__
1325
-
1326
- class HubertModel(BaseFairseqModel):
1327
- def __init__(self, cfg):
1328
- super().__init__()
1329
- feature_enc_layers = eval(cfg.conv_feature_layers)
1330
- self.embed = feature_enc_layers[-1][0]
1331
- self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
1332
- feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
1333
- self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
1334
- self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None)
1335
- self.mask_prob = cfg.mask_prob
1336
- self.mask_selection = cfg.mask_selection
1337
- self.mask_other = cfg.mask_other
1338
- self.mask_length = cfg.mask_length
1339
- self.no_mask_overlap = cfg.no_mask_overlap
1340
- self.mask_min_space = cfg.mask_min_space
1341
- self.mask_channel_prob = cfg.mask_channel_prob
1342
- self.mask_channel_selection = cfg.mask_channel_selection
1343
- self.mask_channel_other = cfg.mask_channel_other
1344
- self.mask_channel_length = cfg.mask_channel_length
1345
- self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
1346
- self.mask_channel_min_space = cfg.mask_channel_min_space
1347
- self.dropout_input = nn.Dropout(cfg.dropout_input)
1348
- self.dropout_features = nn.Dropout(cfg.dropout_features)
1349
- self.feature_grad_mult = cfg.feature_grad_mult
1350
- self.logit_temp = cfg.logit_temp
1351
- self.skip_masked = cfg.skip_masked
1352
- self.skip_nomask = cfg.skip_nomask
1353
- final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
1354
- self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
1355
- self.encoder = TransformerEncoder(cfg)
1356
- self.layer_norm = LayerNorm(self.embed)
1357
- self.target_glu = None
1358
- if cfg.target_glu: self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
1359
- self.untie_final_proj = cfg.untie_final_proj
1360
- self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
1361
- self.num_classes = [504]
1362
- self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
1363
- nn.init.uniform_(self.label_embs_concat)
1364
-
1365
- def upgrade_state_dict_named(self, state_dict, name):
1366
- super().upgrade_state_dict_named(state_dict, name)
1367
- return state_dict
1368
-
1369
- def apply_mask(self, x, padding_mask, target_list):
1370
- B, T, C = x.shape
1371
- if self.mask_prob > 0:
1372
- mask_indices = torch.from_numpy(compute_mask_indices((B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space)).to(x.device)
1373
- x[mask_indices] = self.mask_emb
1374
- else: mask_indices = None
1375
-
1376
- if self.mask_channel_prob > 0: x[(torch.from_numpy(compute_mask_indices((B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space)).to(x.device).unsqueeze(1).expand(-1, T, -1))] = 0
1377
- return x, mask_indices
1378
-
1379
- def compute_nce(self, x, pos, negs):
1380
- neg_is_pos = (pos == negs).all(-1)
1381
- logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
1382
- logits /= self.logit_temp
1383
-
1384
- if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf")
1385
- return logits.transpose(0, 1)
1386
-
1387
- def forward_features(self, source):
1388
- if self.feature_grad_mult > 0:
1389
- features = self.feature_extractor(source)
1390
- if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult)
1391
- else:
1392
- with torch.no_grad():
1393
- features = self.feature_extractor(source)
1394
- return features
1395
-
1396
- def forward_targets(self, features, target_list):
1397
- feat_tsz = features.size(2)
1398
- targ_tsz = min([t.size(1) for t in target_list])
1399
-
1400
- if self.feat2tar_ratio * feat_tsz > targ_tsz:
1401
- feat_tsz = int(targ_tsz / self.feat2tar_ratio)
1402
- features = features[..., :feat_tsz]
1403
-
1404
- return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
1405
-
1406
- def forward_padding_mask(self, features, padding_mask):
1407
- extra = padding_mask.size(1) % features.size(1)
1408
- if extra > 0: padding_mask = padding_mask[:, :-extra]
1409
-
1410
- return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
1411
-
1412
- def forward(self, source, target_list = None, padding_mask = None, mask = True, features_only = False, output_layer = None):
1413
- features = self.forward_features(source)
1414
- if target_list is not None: features, target_list = self.forward_targets(features, target_list)
1415
-
1416
- features_pen = features.float().pow(2).mean()
1417
-
1418
- features = self.layer_norm(features.transpose(1, 2))
1419
- unmasked_features = features.clone()
1420
-
1421
- if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask)
1422
- if self.post_extract_proj is not None: features = self.post_extract_proj(features)
1423
-
1424
- features = self.dropout_input(features)
1425
- unmasked_features = self.dropout_features(unmasked_features)
1426
-
1427
- if mask: x, mask_indices = self.apply_mask(features, padding_mask, target_list)
1428
- else: x, mask_indices = features, None
1429
-
1430
- x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
1431
- if features_only: return {"x": x, "padding_mask": padding_mask, "features": features}
1432
-
1433
- def compute_pred(proj_x, target, label_embs):
1434
- y = torch.index_select(label_embs, 0, target.long())
1435
- negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
1436
-
1437
- if self.target_glu:
1438
- y = self.target_glu(y)
1439
- negs = self.target_glu(negs)
1440
-
1441
- return self.compute_nce(proj_x, y, negs)
1442
-
1443
- label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
1444
-
1445
- if not self.skip_masked:
1446
- masked_indices = torch.logical_and(~padding_mask, mask_indices)
1447
- proj_x_m = self.final_proj(x[masked_indices])
1448
- logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) for i, (proj_x_m, t) in enumerate(zip(proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], target_list))]
1449
- else: logit_m_list = [None for _ in target_list]
1450
-
1451
- if not self.skip_nomask:
1452
- nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
1453
- proj_x_u = self.final_proj(x[nomask_indices])
1454
- logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) for i, (proj_x_u, t) in enumerate(zip(proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], target_list))]
1455
- else: logit_u_list = [None for _ in target_list]
1456
-
1457
- return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
1458
-
1459
- def extract_features(self, source, padding_mask = None, mask = False, ret_conv = False, output_layer = None):
1460
- res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
1461
- return res["features"] if ret_conv else res["x"], res["padding_mask"]
1462
-
1463
- def get_logits(self, net_output, is_masked=True):
1464
- return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
1465
-
1466
- def get_targets(self, net_output, is_masked=True):
1467
- return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
1468
-
1469
- def get_extra_losses(self, net_output):
1470
- extra_losses, names = [], []
1471
-
1472
- if "features_pen" in net_output:
1473
- extra_losses.append(net_output["features_pen"])
1474
- names.append("features_pen")
1475
-
1476
- return extra_losses, names
1477
-
1478
- def remove_pretraining_modules(self):
1479
- self.target_glu = None
1480
- self.final_proj = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/architectures/mdx_separator.py DELETED
@@ -1,320 +0,0 @@
1
- import os
2
- import sys
3
- import onnx
4
- import torch
5
- import platform
6
- import onnx2torch
7
-
8
- import numpy as np
9
- import onnxruntime as ort
10
-
11
- from tqdm import tqdm
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from main.configs.config import Config
16
- from main.library.uvr5_separator import spec_utils
17
- from main.library.uvr5_separator.common_separator import CommonSeparator
18
-
19
- translations = Config().translations
20
-
21
- class MDXSeparator(CommonSeparator):
22
- def __init__(self, common_config, arch_config):
23
- super().__init__(config=common_config)
24
- self.segment_size = arch_config.get("segment_size")
25
- self.overlap = arch_config.get("overlap")
26
- self.batch_size = arch_config.get("batch_size", 1)
27
- self.hop_length = arch_config.get("hop_length")
28
- self.enable_denoise = arch_config.get("enable_denoise")
29
- self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
30
- self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
31
- self.compensate = self.model_data["compensate"]
32
- self.dim_f = self.model_data["mdx_dim_f_set"]
33
- self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
34
- self.n_fft = self.model_data["mdx_n_fft_scale_set"]
35
- self.config_yaml = self.model_data.get("config_yaml", None)
36
- self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
37
- self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
38
- self.load_model()
39
- self.n_bins = 0
40
- self.trim = 0
41
- self.chunk_size = 0
42
- self.gen_size = 0
43
- self.stft = None
44
- self.primary_source = None
45
- self.secondary_source = None
46
- self.audio_file_path = None
47
- self.audio_file_base = None
48
-
49
- def load_model(self):
50
- self.logger.debug(translations["load_model_onnx"])
51
-
52
- if self.segment_size == self.dim_t:
53
- ort_session_options = ort.SessionOptions()
54
- ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
55
- ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
56
- self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
57
- self.logger.debug(translations["load_model_onnx_success"])
58
- else:
59
- self.model_run = onnx2torch.convert(onnx.load(self.model_path)) if platform.system() == 'Windows' else onnx2torch.convert(self.model_path)
60
- self.model_run.to(self.torch_device).eval()
61
- self.logger.debug(translations["onnx_to_pytorch"])
62
-
63
- def separate(self, audio_file_path):
64
- self.audio_file_path = audio_file_path
65
- self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
66
- self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
67
- mix = self.prepare_mix(self.audio_file_path)
68
- self.logger.debug(translations["normalization_demix"])
69
- mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
70
- source = self.demix(mix)
71
- self.logger.debug(translations["mix_success"])
72
- output_files = []
73
- self.logger.debug(translations["process_output_file"])
74
-
75
- if not isinstance(self.primary_source, np.ndarray):
76
- self.logger.debug(translations["primary_source"])
77
- self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T
78
-
79
- if not isinstance(self.secondary_source, np.ndarray):
80
- self.logger.debug(translations["secondary_source"])
81
- raw_mix = self.demix(mix, is_match_mix=True)
82
-
83
- if self.invert_using_spec:
84
- self.logger.debug(translations["invert_using_spec"])
85
- self.secondary_source = spec_utils.invert_stem(raw_mix, source)
86
- else:
87
- self.logger.debug(translations["invert_using_spec_2"])
88
- self.secondary_source = mix.T - source.T
89
-
90
- if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
91
- self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
92
- self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
93
- self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
94
- output_files.append(self.secondary_stem_output_path)
95
-
96
- if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
97
- self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
98
- if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T
99
-
100
- self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
101
- self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
102
- output_files.append(self.primary_stem_output_path)
103
-
104
- return output_files
105
-
106
- def initialize_model_settings(self):
107
- self.logger.debug(translations["starting_model"])
108
-
109
- self.n_bins = self.n_fft // 2 + 1
110
- self.trim = self.n_fft // 2
111
-
112
- self.chunk_size = self.hop_length * (self.segment_size - 1)
113
- self.gen_size = self.chunk_size - 2 * self.trim
114
-
115
- self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
116
-
117
- self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
118
- self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")
119
-
120
- def initialize_mix(self, mix, is_ckpt=False):
121
- self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))
122
-
123
- if mix.shape[0] != 2:
124
- error_message = translations["!=2"].format(shape=mix.shape[0])
125
- self.logger.error(error_message)
126
- raise ValueError(error_message)
127
-
128
- if is_ckpt:
129
- self.logger.debug(translations["process_check"])
130
- pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
131
- self.logger.debug(f"{translations['cache']}: {pad}")
132
-
133
- mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
134
-
135
- num_chunks = mixture.shape[-1] // self.gen_size
136
- self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))
137
-
138
- mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
139
- else:
140
- self.logger.debug(translations["process_no_check"])
141
- mix_waves = []
142
- n_sample = mix.shape[1]
143
-
144
- pad = self.gen_size - n_sample % self.gen_size
145
- self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))
146
-
147
- mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
148
- self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")
149
-
150
- i = 0
151
- while i < n_sample + pad:
152
- mix_waves.append(np.array(mix_p[:, i : i + self.chunk_size]))
153
-
154
- self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
155
- i += self.gen_size
156
-
157
- mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
158
- self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))
159
-
160
- return mix_waves_tensor, pad
161
-
162
- def demix(self, mix, is_match_mix=False):
163
- self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
164
- self.initialize_model_settings()
165
- self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")
166
- tar_waves_ = []
167
-
168
- if is_match_mix:
169
- chunk_size = self.hop_length * (self.segment_size - 1)
170
- overlap = 0.02
171
- self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
172
- else:
173
- chunk_size = self.chunk_size
174
- overlap = self.overlap
175
- self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))
176
-
177
- gen_size = chunk_size - 2 * self.trim
178
- self.logger.debug(f"{translations['calc_size']}: {gen_size}")
179
-
180
- mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, gen_size + self.trim - ((mix.shape[-1]) % gen_size)), dtype="float32")), 1)
181
- self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")
182
-
183
- step = int((1 - overlap) * chunk_size)
184
- self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))
185
-
186
- result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
187
- divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
188
-
189
- total = 0
190
- total_chunks = (mixture.shape[-1] + step - 1) // step
191
- self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")
192
-
193
- for i in tqdm(range(0, mixture.shape[-1], step), ncols=100, unit="f"):
194
- total += 1
195
- start = i
196
- end = min(i + chunk_size, mixture.shape[-1])
197
- self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))
198
-
199
- chunk_size_actual = end - start
200
- window = None
201
-
202
- if overlap != 0:
203
- window = np.hanning(chunk_size_actual)
204
- window = np.tile(window[None, None, :], (1, 2, 1))
205
- self.logger.debug(translations["window"])
206
-
207
- mix_part_ = mixture[:, start:end]
208
-
209
- if end != i + chunk_size:
210
- pad_size = (i + chunk_size) - end
211
- mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
212
-
213
- mix_waves = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device).split(self.batch_size)
214
-
215
- total_batches = len(mix_waves)
216
- self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")
217
-
218
- with torch.no_grad():
219
- batches_processed = 0
220
-
221
- for mix_wave in mix_waves:
222
- batches_processed += 1
223
- self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")
224
-
225
- tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
226
-
227
- if window is not None:
228
- tar_waves[..., :chunk_size_actual] *= window
229
- divider[..., start:end] += window
230
- else: divider[..., start:end] += 1
231
-
232
- result[..., start:end] += tar_waves[..., : end - start]
233
-
234
-
235
- self.logger.debug(translations["normalization_2"])
236
- tar_waves = result / divider
237
- tar_waves_.append(tar_waves)
238
-
239
- tar_waves = np.concatenate(np.vstack(tar_waves_)[:, :, self.trim : -self.trim], axis=-1)[:, : mix.shape[-1]]
240
-
241
- source = tar_waves[:, 0:None]
242
- self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")
243
-
244
- if not is_match_mix:
245
- source *= self.compensate
246
- self.logger.debug(translations["mix_match"])
247
-
248
- self.logger.debug(translations["mix_success"])
249
- return source
250
-
251
- def run_model(self, mix, is_match_mix=False):
252
- spek = self.stft(mix.to(self.torch_device))
253
- self.logger.debug(translations["stft_2"].format(shape=spek.shape))
254
-
255
- spek[:, :, :3, :] *= 0
256
-
257
- if is_match_mix:
258
- spec_pred = spek.cpu().numpy()
259
- self.logger.debug(translations["is_match_mix"])
260
- else:
261
- if self.enable_denoise:
262
- spec_pred_neg = self.model_run(-spek)
263
- spec_pred_pos = self.model_run(spek)
264
- spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
265
- self.logger.debug(translations["enable_denoise"])
266
- else:
267
- spec_pred = self.model_run(spek)
268
- self.logger.debug(translations["no_denoise"])
269
-
270
- result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
271
- self.logger.debug(f"{translations['stft']}: {result.shape}")
272
-
273
- return result
274
-
275
- class STFT:
276
- def __init__(self, logger, n_fft, hop_length, dim_f, device):
277
- self.logger = logger
278
- self.n_fft = n_fft
279
- self.hop_length = hop_length
280
- self.dim_f = dim_f
281
- self.device = device
282
- self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
283
-
284
- def __call__(self, input_tensor):
285
- is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
286
-
287
- if is_non_standard_device: input_tensor = input_tensor.cpu()
288
-
289
- batch_dimensions = input_tensor.shape[:-2]
290
- channel_dim, time_dim = input_tensor.shape[-2:]
291
-
292
- permuted_stft_output = torch.stft(input_tensor.reshape([-1, time_dim]), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True, return_complex=False).permute([0, 3, 1, 2])
293
- final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])
294
-
295
- if is_non_standard_device: final_output = final_output.to(self.device)
296
- return final_output[..., : self.dim_f, :]
297
-
298
- def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
299
- return torch.cat([input_tensor, torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)], -2)
300
-
301
- def calculate_inverse_dimensions(self, input_tensor):
302
- channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
303
-
304
- return input_tensor.shape[:-3], channel_dim, freq_dim, time_dim, self.n_fft // 2 + 1
305
-
306
- def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
307
- permuted_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]).reshape([-1, 2, num_freq_bins, time_dim]).permute([0, 2, 3, 1])
308
-
309
- return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
310
-
311
- def inverse(self, input_tensor):
312
- is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
313
- if is_non_standard_device: input_tensor = input_tensor.cpu()
314
-
315
- batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
316
- final_output = torch.istft(self.prepare_for_istft(self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins), batch_dimensions, channel_dim, num_freq_bins, time_dim), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True).reshape([*batch_dimensions, 2, -1])
317
-
318
- if is_non_standard_device: final_output = final_output.to(self.device)
319
-
320
- return final_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/audioldm2/models.py DELETED
@@ -1,330 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
- import librosa
5
-
6
- import numpy as np
7
- import torch.nn.functional as F
8
-
9
- from scipy.signal import get_window
10
- from librosa.util import pad_center
11
- from diffusers import DDIMScheduler, AudioLDM2Pipeline
12
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
13
- from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer
14
-
15
- sys.path.append(os.getcwd())
16
-
17
- from main.configs.config import Config
18
- from main.library.utils import check_audioldm2
19
-
20
- config = Config()
21
-
22
- class Pipeline(torch.nn.Module):
23
- def __init__(self, model_id, device, double_precision = False, token = None, *args, **kwargs):
24
- super().__init__(*args, **kwargs)
25
- self.model_id = model_id
26
- self.device = device
27
- self.double_precision = double_precision
28
- self.token = token
29
-
30
- def load_scheduler(self):
31
- pass
32
-
33
- def get_melspectrogram(self):
34
- pass
35
-
36
- def vae_encode(self, x):
37
- pass
38
-
39
- def vae_decode(self, x):
40
- pass
41
-
42
- def decode_to_mel(self, x):
43
- pass
44
-
45
- def setup_extra_inputs(self, *args, **kwargs):
46
- pass
47
-
48
- def encode_text(self, prompts, **kwargs):
49
- pass
50
-
51
- def get_variance(self, timestep, prev_timestep):
52
- pass
53
-
54
- def get_alpha_prod_t_prev(self, prev_timestep):
55
- pass
56
-
57
- def get_noise_shape(self, x0, num_steps):
58
- return (num_steps, self.model.unet.config.in_channels, x0.shape[-2], x0.shape[-1])
59
-
60
- def sample_xts_from_x0(self, x0, num_inference_steps = 50):
61
- alpha_bar = self.model.scheduler.alphas_cumprod
62
- sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
63
- timesteps = self.model.scheduler.timesteps.to(self.device)
64
- t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
65
- xts = torch.zeros(self.get_noise_shape(x0, num_inference_steps + 1)).to(x0.device)
66
- xts[0] = x0
67
-
68
- for t in reversed(timesteps):
69
- idx = num_inference_steps - t_to_idx[int(t)]
70
- xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
71
-
72
- return xts
73
-
74
- def get_zs_from_xts(self, xt, xtm1, noise_pred, t, eta = 0, numerical_fix = True, **kwargs):
75
- alpha_bar = self.model.scheduler.alphas_cumprod
76
-
77
- if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
78
- elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
79
-
80
- prev_timestep = t - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
81
- alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
82
- variance = self.get_variance(t, prev_timestep)
83
-
84
- if self.model.scheduler.config.prediction_type == 'epsilon': radom_noise_pred = noise_pred
85
- elif self.model.scheduler.config.prediction_type == 'v_prediction': radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
86
-
87
- mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * radom_noise_pred)
88
- z = (xtm1 - mu_xt) / (eta * variance ** 0.5)
89
-
90
- if numerical_fix: xtm1 = mu_xt + (eta * variance ** 0.5)*z
91
- return z, xtm1, None
92
-
93
- def reverse_step_with_custom_noise(self, model_output, timestep, sample, variance_noise = None, eta = 0, **kwargs):
94
- prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
95
- alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
96
- alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
97
- beta_prod_t = 1 - alpha_prod_t
98
-
99
- if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
100
- elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
101
-
102
- variance = self.get_variance(timestep, prev_timestep)
103
-
104
- if self.model.scheduler.config.prediction_type == 'epsilon': model_output_direction = model_output
105
- elif self.model.scheduler.config.prediction_type == 'v_prediction': model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
106
-
107
- prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction)
108
-
109
- if eta > 0:
110
- if variance_noise is None: variance_noise = torch.randn(model_output.shape, device=self.device)
111
- prev_sample = prev_sample + (eta * variance ** (0.5) * variance_noise)
112
-
113
- return prev_sample
114
-
115
- def unet_forward(self, sample, timestep, encoder_hidden_states, class_labels = None, timestep_cond = None, attention_mask = None, cross_attention_kwargs = None, added_cond_kwargs = None, down_block_additional_residuals = None, mid_block_additional_residual = None, encoder_attention_mask = None, replace_h_space = None, replace_skip_conns = None, return_dict = True, zero_out_resconns = None):
116
- pass
117
-
118
- class STFT(torch.nn.Module):
119
- def __init__(self, fft_size, hop_size, window_size, window_type="hann"):
120
- super().__init__()
121
- self.fft_size = fft_size
122
- self.hop_size = hop_size
123
- self.window_size = window_size
124
- self.window_type = window_type
125
-
126
- scale = fft_size / hop_size
127
- fourier_basis = np.fft.fft(np.eye(fft_size))
128
-
129
- cutoff = fft_size // 2 + 1
130
- fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
131
-
132
- self.forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
133
- self.inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])
134
-
135
- if window_type:
136
- assert fft_size >= window_size
137
-
138
- fft_window = torch.from_numpy(pad_center(get_window(window_type, window_size, fftbins=True), size=fft_size)).float()
139
- self.forward_basis *= fft_window
140
- self.inverse_basis *= fft_window
141
-
142
- if not hasattr(self, "forward_basis"): self.register_buffer("forward_basis", self.forward_basis)
143
- if not hasattr(self, "inverse_basis"): self.register_buffer("inverse_basis", self.inverse_basis)
144
-
145
- def transform(self, signal):
146
- batch_size, num_samples = signal.shape
147
- transformed_signal = F.conv1d(F.pad(signal.view(batch_size, 1, num_samples).unsqueeze(1), (self.fft_size // 2, self.fft_size // 2, 0, 0), mode="reflect").squeeze(1), self.forward_basis, stride=self.hop_size, padding=0).cpu()
148
-
149
- cutoff = self.fft_size // 2 + 1
150
- real_part, imag_part = transformed_signal[:, :cutoff, :], transformed_signal[:, cutoff:, :]
151
-
152
- return torch.sqrt(real_part ** 2 + imag_part ** 2), torch.atan2(imag_part, real_part)
153
-
154
- class MelSpectrogramProcessor(torch.nn.Module):
155
- def __init__(self, fft_size, hop_size, window_size, num_mel_bins, sample_rate, fmin, fmax):
156
- super().__init__()
157
- self.num_mel_bins = num_mel_bins
158
- self.sample_rate = sample_rate
159
- self.stft_processor = STFT(fft_size, hop_size, window_size)
160
- self.register_buffer("mel_filter", torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mel_bins, fmin=fmin, fmax=fmax)).float())
161
-
162
- def compute_mel_spectrogram(self, waveform, normalization_fn=torch.log):
163
- assert torch.min(waveform) >= -1
164
- assert torch.max(waveform) <= 1
165
-
166
- magnitudes, _ = self.stft_processor.transform(waveform)
167
- return normalization_fn(torch.clamp(torch.matmul(self.mel_filter, magnitudes), min=1e-5))
168
-
169
- class AudioLDM2(Pipeline):
170
- def __init__(self, *args, **kwargs):
171
- super().__init__(*args, **kwargs)
172
- self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, torch_dtype=torch.float16 if config.is_half else torch.float32).to(self.device)
173
-
174
- def load_scheduler(self):
175
- self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
176
-
177
- def get_melspectrogram(self):
178
- return MelSpectrogramProcessor(fft_size=1024, hop_size=160, window_size=1024, num_mel_bins=64, sample_rate=16000, fmin=0, fmax=8000)
179
-
180
- def vae_encode(self, x):
181
- if x.shape[2] % 4: x = F.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
182
- output = (self.model.vae.encode(x.half() if config.is_half else x.float()).latent_dist.mode() * self.model.vae.config.scaling_factor)
183
- return output.half() if config.is_half else output.float()
184
-
185
- def vae_decode(self, x):
186
- return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
187
-
188
- def decode_to_mel(self, x):
189
- tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().to(torch.float16 if config.is_half else torch.float32)).detach()
190
-
191
- if len(tmp.shape) == 1: tmp = tmp.unsqueeze(0)
192
- return tmp
193
-
194
- def encode_text(self, prompts, negative = False, save_compute = False, cond_length = 0, **kwargs):
195
- tokenizers, text_encoders = [self.model.tokenizer, self.model.tokenizer_2], [self.model.text_encoder, self.model.text_encoder_2]
196
- prompt_embeds_list, attention_mask_list = [], []
197
-
198
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
199
- text_inputs = tokenizer(prompts, padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True, max_length=tokenizer.model_max_length if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))) else cond_length, truncation=True, return_tensors="pt")
200
- text_input_ids = text_inputs.input_ids
201
-
202
- attention_mask = text_inputs.attention_mask
203
- untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
204
-
205
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])
206
-
207
- text_input_ids = text_input_ids.to(self.device)
208
- attention_mask = attention_mask.to(self.device)
209
-
210
- with torch.no_grad():
211
- if text_encoder.config.model_type == "clap":
212
- prompt_embeds = text_encoder.get_text_features(text_input_ids, attention_mask=attention_mask)
213
- prompt_embeds = prompt_embeds[:, None, :]
214
- attention_mask = attention_mask.new_ones((len(prompts), 1))
215
- else: prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)[0]
216
-
217
- prompt_embeds_list.append(prompt_embeds)
218
- attention_mask_list.append(attention_mask)
219
-
220
- projection_output = self.model.projection_model(hidden_states=prompt_embeds_list[0], hidden_states_1=prompt_embeds_list[1], attention_mask=attention_mask_list[0], attention_mask_1=attention_mask_list[1])
221
- generated_prompt_embeds = self.model.generate_language_model(projection_output.hidden_states, attention_mask=projection_output.attention_mask, max_new_tokens=None)
222
- prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device)
223
- return generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device), prompt_embeds, (attention_mask.to(device=self.device) if attention_mask is not None else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device))
224
-
225
- def get_variance(self, timestep, prev_timestep):
226
- alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
227
- alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
228
- return ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)
229
-
230
- def get_alpha_prod_t_prev(self, prev_timestep):
231
- return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.model.scheduler.final_alpha_cumprod
232
-
233
- def unet_forward(self, sample, timestep, encoder_hidden_states, timestep_cond = None, class_labels = None, attention_mask = None, encoder_attention_mask = None, return_dict = True, cross_attention_kwargs = None, mid_block_additional_residual = None, replace_h_space = None, replace_skip_conns = None, zero_out_resconns = None):
234
- encoder_hidden_states_1 = class_labels
235
- class_labels = None
236
- encoder_attention_mask_1 = encoder_attention_mask
237
- encoder_attention_mask = None
238
- default_overall_up_factor = 2 ** self.model.unet.num_upsamplers
239
- forward_upsample_size = False
240
- upsample_size = None
241
-
242
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): forward_upsample_size = True
243
-
244
- if attention_mask is not None:
245
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
246
- attention_mask = attention_mask.unsqueeze(1)
247
-
248
- if encoder_attention_mask is not None:
249
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
250
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
251
-
252
- if encoder_attention_mask_1 is not None:
253
- encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
254
- encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
255
-
256
- timesteps = timestep
257
- if not torch.is_tensor(timesteps):
258
- is_mps = sample.device.type == "mps"
259
-
260
- dtype = (torch.float16 if is_mps else torch.float32) if isinstance(timestep, float) else (torch.int16 if is_mps else torch.int32)
261
-
262
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
263
- elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device)
264
-
265
- emb = self.model.unet.time_embedding(self.model.unet.time_proj(timesteps.expand(sample.shape[0])).to(dtype=sample.dtype), timestep_cond)
266
- aug_emb = None
267
-
268
- if self.model.unet.class_embedding is not None:
269
- if class_labels is None: raise ValueError
270
-
271
- if self.model.unet.config.class_embed_type == "timestep": class_labels = self.model.unet.time_proj(class_labels).to(dtype=sample.dtype)
272
- class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
273
-
274
- if self.model.unet.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1)
275
- else: emb = emb + class_emb
276
-
277
- emb = emb + aug_emb if aug_emb is not None else emb
278
- if self.model.unet.time_embed_act is not None: emb = self.model.unet.time_embed_act(emb)
279
-
280
- sample = self.model.unet.conv_in(sample)
281
- down_block_res_samples = (sample,)
282
-
283
- for downsample_block in self.model.unet.down_blocks:
284
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
285
- else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
286
-
287
- down_block_res_samples += res_samples
288
-
289
- if self.model.unet.mid_block is not None: sample = self.model.unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
290
-
291
- if replace_h_space is None: h_space = sample.clone()
292
- else:
293
- h_space = replace_h_space
294
- sample = replace_h_space.clone()
295
-
296
- if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual
297
- extracted_res_conns = {}
298
-
299
- for i, upsample_block in enumerate(self.model.unet.up_blocks):
300
- is_final_block = i == len(self.model.unet.up_blocks) - 1
301
- res_samples = down_block_res_samples[-len(upsample_block.resnets):]
302
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
303
-
304
- if replace_skip_conns is not None and replace_skip_conns.get(i): res_samples = replace_skip_conns.get(i)
305
-
306
- if zero_out_resconns is not None:
307
- if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or type(zero_out_resconns) is list and i in zero_out_resconns: res_samples = [torch.zeros_like(x) for x in res_samples]
308
-
309
- extracted_res_conns[i] = res_samples
310
- if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:]
311
-
312
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
313
- else: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)
314
-
315
- if self.model.unet.conv_norm_out: sample = self.model.unet.conv_act(self.model.unet.conv_norm_out(sample))
316
- sample = self.model.unet.conv_out(sample)
317
-
318
- if not return_dict: return (sample,)
319
- return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
320
-
321
- def load_model(model, device):
322
- check_audioldm2(model)
323
-
324
- ldm_stable = AudioLDM2(model_id=os.path.join("assets", "models", "audioldm2", model), device=device, double_precision=False)
325
- ldm_stable.load_scheduler()
326
-
327
- if torch.cuda.is_available(): torch.cuda.empty_cache()
328
- elif torch.backends.mps.is_available(): torch.mps.empty_cache()
329
-
330
- return ldm_stable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/audioldm2/utils.py DELETED
@@ -1,40 +0,0 @@
1
- import torch
2
- import librosa
3
- import torchaudio
4
-
5
- import numpy as np
6
-
7
- def compute_mel_spectrogram(audio, stft_processor):
8
- return stft_processor.compute_mel_spectrogram(torch.autograd.Variable(torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1), requires_grad=False)).squeeze(0).numpy().astype(np.float32)
9
-
10
- def pad_spectrogram(spectrogram, target_length=1024):
11
- pad_amount = target_length - spectrogram.shape[0]
12
- spectrogram = torch.nn.functional.pad(spectrogram, (0, 0, 0, pad_amount)) if pad_amount > 0 else spectrogram[:target_length, :]
13
-
14
- if spectrogram.size(-1) % 2 != 0: spectrogram = spectrogram[..., :-1]
15
- return spectrogram
16
-
17
- def pad_waveform(waveform, segment_length):
18
- waveform_length = waveform.shape[-1]
19
- assert waveform_length > 100
20
-
21
- if segment_length is None or waveform_length == segment_length: return waveform
22
- elif waveform_length > segment_length: return waveform[:, :segment_length]
23
-
24
- padded_waveform = np.zeros((1, segment_length))
25
- padded_waveform[:, :waveform_length] = waveform
26
- return padded_waveform
27
-
28
- def normalize(waveform):
29
- waveform -= np.mean(waveform)
30
- return (waveform / (np.max(np.abs(waveform)) + 1e-8)) * 0.5
31
-
32
- def process_audio(y, sr, segment_length):
33
- normalized_waveform = normalize(torchaudio.functional.resample(torch.from_numpy(y), orig_freq=sr, new_freq=16000).numpy())[None, ...]
34
- return 0.5 * (pad_waveform(normalized_waveform, segment_length) / np.max(np.abs(normalized_waveform)))
35
-
36
- def load_audio(audio_path, stft_processor, device=None):
37
- y, sr = librosa.load(audio_path, sr=None)
38
- duration = len(y) / sr
39
-
40
- return pad_spectrogram(torch.FloatTensor(compute_mel_spectrogram(torch.FloatTensor(process_audio(y, sr, int(duration * 102.4) * 160)[0, ...]), stft_processor).T), int(duration * 102.4)).unsqueeze(0).unsqueeze(0).to(device), duration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/predictors/CREPE.py DELETED
@@ -1,210 +0,0 @@
1
- import os
2
- import torch
3
- import librosa
4
- import functools
5
- import scipy.stats
6
-
7
- import numpy as np
8
-
9
- CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024
10
-
11
- class Crepe(torch.nn.Module):
12
- def __init__(self, model='full'):
13
- super().__init__()
14
- if model == 'full':
15
- in_channels = [1, 1024, 128, 128, 128, 256]
16
- out_channels = [1024, 128, 128, 128, 256, 512]
17
- self.in_features = 2048
18
- elif model == 'large':
19
- in_channels = [1, 768, 96, 96, 96, 192]
20
- out_channels = [768, 96, 96, 96, 192, 384]
21
- self.in_features = 1536
22
- elif model == 'medium':
23
- in_channels = [1, 512, 64, 64, 64, 128]
24
- out_channels = [512, 64, 64, 64, 128, 256]
25
- self.in_features = 1024
26
- elif model == 'small':
27
- in_channels = [1, 256, 32, 32, 32, 64]
28
- out_channels = [256, 32, 32, 32, 64, 128]
29
- self.in_features = 512
30
- elif model == 'tiny':
31
- in_channels = [1, 128, 16, 16, 16, 32]
32
- out_channels = [128, 16, 16, 16, 32, 64]
33
- self.in_features = 256
34
-
35
- kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
36
- strides = [(4, 1)] + 5 * [(1, 1)]
37
-
38
- batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
39
-
40
- self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
41
- self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
42
- self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
43
- self.conv2_BN = batch_norm_fn(num_features=out_channels[1])
44
-
45
- self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
46
- self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
47
- self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
48
- self.conv4_BN = batch_norm_fn(num_features=out_channels[3])
49
-
50
- self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
51
- self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
52
- self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
53
- self.conv6_BN = batch_norm_fn(num_features=out_channels[5])
54
-
55
- self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)
56
-
57
- def forward(self, x, embed=False):
58
- x = self.embed(x)
59
- if embed: return x
60
-
61
- return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))
62
-
63
- def embed(self, x):
64
- x = x[:, None, :, None]
65
-
66
- return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)
67
-
68
- def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
69
- return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))
70
-
71
- def viterbi(logits):
72
- if not hasattr(viterbi, 'transition'):
73
- xx, yy = np.meshgrid(range(360), range(360))
74
- transition = np.maximum(12 - abs(xx - yy), 0)
75
- viterbi.transition = transition / transition.sum(axis=1, keepdims=True)
76
-
77
- with torch.no_grad():
78
- probs = torch.nn.functional.softmax(logits, dim=1)
79
-
80
- bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
81
- return bins, bins_to_frequency(bins)
82
-
83
- def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
84
- results = []
85
-
86
- if onnx:
87
- import onnxruntime as ort
88
-
89
- sess_options = ort.SessionOptions()
90
- sess_options.log_severity_level = 3
91
-
92
- session = ort.InferenceSession(os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx"), sess_options=sess_options, providers=providers)
93
-
94
- for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
95
- result = postprocess(torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: frames.cpu().numpy()})[0].transpose(1, 0)[None]), fmin, fmax, return_periodicity)
96
- results.append((result[0], result[1]) if isinstance(result, tuple) else result)
97
-
98
- del session
99
-
100
- if return_periodicity:
101
- pitch, periodicity = zip(*results)
102
- return torch.cat(pitch, 1), torch.cat(periodicity, 1)
103
-
104
- return torch.cat(results, 1)
105
- else:
106
- with torch.no_grad():
107
- for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
108
- result = postprocess(infer(frames, model, device, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2), fmin, fmax, return_periodicity)
109
- results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))
110
-
111
- if return_periodicity:
112
- pitch, periodicity = zip(*results)
113
- return torch.cat(pitch, 1), torch.cat(periodicity, 1)
114
-
115
- return torch.cat(results, 1)
116
-
117
- def bins_to_frequency(bins):
118
- cents = CENTS_PER_BIN * bins + 1997.3794084376191
119
- return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)
120
-
121
- def frequency_to_bins(frequency, quantize_fn=torch.floor):
122
- return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()
123
-
124
- def infer(frames, model='full', device='cpu', embed=False):
125
- if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or (hasattr(infer, 'capacity') and infer.capacity != model): load_model(device, model)
126
- infer.model = infer.model.to(device)
127
-
128
- return infer.model(frames, embed=embed)
129
-
130
- def load_model(device, capacity='full'):
131
- infer.capacity = capacity
132
- infer.model = Crepe(capacity)
133
- infer.model.load_state_dict(torch.load(os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth"), map_location=device))
134
- infer.model = infer.model.to(torch.device(device))
135
- infer.model.eval()
136
-
137
- def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
138
- probabilities = probabilities.detach()
139
-
140
- probabilities[:, :frequency_to_bins(torch.tensor(fmin))] = -float('inf')
141
- probabilities[:, frequency_to_bins(torch.tensor(fmax), torch.ceil):] = -float('inf')
142
-
143
- bins, pitch = viterbi(probabilities)
144
-
145
- if not return_periodicity: return pitch
146
- return pitch, periodicity(probabilities, bins)
147
-
148
- def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
149
- hop_length = sample_rate // 100 if hop_length is None else hop_length
150
-
151
- if sample_rate != SAMPLE_RATE:
152
- audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
153
- hop_length = int(hop_length * SAMPLE_RATE / sample_rate)
154
-
155
- if pad:
156
- total_frames = 1 + int(audio.size(1) // hop_length)
157
- audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
158
- else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
159
-
160
- batch_size = total_frames if batch_size is None else batch_size
161
-
162
- for i in range(0, total_frames, batch_size):
163
- frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
164
- frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
165
- frames -= frames.mean(dim=1, keepdim=True)
166
- frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))
167
-
168
- yield frames
169
-
170
- def periodicity(probabilities, bins):
171
- probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
172
- periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
173
-
174
- return periodicity.reshape(probabilities.size(0), probabilities.size(2))
175
-
176
- def mean(signals, win_length=9):
177
- assert signals.dim() == 2
178
-
179
- signals = signals.unsqueeze(1)
180
- mask = ~torch.isnan(signals)
181
- padding = win_length // 2
182
-
183
- ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
184
- avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
185
- avg_pooled[avg_pooled == 0] = float("nan")
186
-
187
- return avg_pooled.squeeze(1)
188
-
189
- def median(signals, win_length):
190
- assert signals.dim() == 2
191
-
192
- signals = signals.unsqueeze(1)
193
- mask = ~torch.isnan(signals)
194
- padding = win_length // 2
195
-
196
- x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
197
- mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)
198
-
199
- x = x.unfold(2, win_length, 1)
200
- mask = mask.unfold(2, win_length, 1)
201
-
202
- x = x.contiguous().view(x.size()[:3] + (-1,))
203
- mask = mask.contiguous().view(mask.size()[:3] + (-1,))
204
-
205
- x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)
206
-
207
- median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
208
- median_pooled[torch.isinf(median_pooled)] = float("nan")
209
-
210
- return median_pooled.squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/predictors/FCPE.py DELETED
@@ -1,1000 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
-
5
- import numpy as np
6
- import onnxruntime as ort
7
- import torch.nn.functional as F
8
-
9
- from torch import nn, einsum
10
- from functools import partial
11
- from torchaudio.transforms import Resample
12
- from einops import rearrange, repeat, pack, unpack
13
- from torch.nn.utils.parametrizations import weight_norm
14
-
15
- from librosa.filters import mel as librosa_mel_fn
16
-
17
- os.environ["LRU_CACHE_CAPACITY"] = "3"
18
-
19
- def exists(val):
20
- return val is not None
21
-
22
- def default(value, d):
23
- return value if exists(value) else d
24
-
25
- def empty(tensor):
26
- return tensor.numel() == 0
27
-
28
- def decrypt_model(input_path):
29
- from io import BytesIO
30
- from Crypto.Cipher import AES
31
- from Crypto.Util.Padding import unpad
32
-
33
- with open(input_path, "rb") as f:
34
- data = f.read()
35
-
36
- with open(os.path.join("main", "configs", "decrypt.bin"), "rb") as f:
37
- key = f.read()
38
-
39
- return BytesIO(unpad(AES.new(key, AES.MODE_CBC, data[:16]).decrypt(data[16:]), AES.block_size)).read()
40
-
41
- def l2_regularization(model, l2_alpha):
42
- l2_loss = []
43
- for module in model.modules():
44
- if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
45
-
46
- return l2_alpha * sum(l2_loss)
47
-
48
- def pad_to_multiple(tensor, multiple, dim=-1, value=0):
49
- seqlen = tensor.shape[dim]
50
- m = seqlen / multiple
51
- if m.is_integer(): return False, tensor
52
- return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)
53
-
54
- def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
55
- t = x.shape[1]
56
- dims = (len(x.shape) - dim) * (0, 0)
57
- padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
58
- return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)
59
-
60
- def rotate_half(x):
61
- x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
62
- return torch.cat((-x2, x1), dim = -1)
63
-
64
- def apply_rotary_pos_emb(q, k, freqs, scale = 1):
65
- q_len = q.shape[-2]
66
- q_freqs = freqs[..., -q_len:, :]
67
- inv_scale = scale ** -1
68
- if scale.ndim == 2: scale = scale[-q_len:, :]
69
- q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
70
- k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
71
-
72
- return q, k
73
-
74
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
75
- return torch.log(torch.clamp(x, min=clip_val) * C)
76
-
77
- def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
78
- unstructured_block = torch.randn((cols, cols), device=device)
79
- q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
80
- q, r = map(lambda t: t.to(device), (q, r))
81
- if qr_uniform_q:
82
- d = torch.diag(r, 0)
83
- q *= d.sign()
84
-
85
- return q.t()
86
-
87
- def linear_attention(q, k, v):
88
- return einsum("...ed,...nd->...ne", k, q) if v is None else einsum("...de,...nd,...n->...ne", einsum("...nd,...ne->...de", k, v), q, 1.0 / (einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))
89
-
90
- def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
91
- nb_full_blocks = int(nb_rows / nb_columns)
92
- block_list = []
93
- for _ in range(nb_full_blocks):
94
- block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))
95
-
96
- remaining_rows = nb_rows - nb_full_blocks * nb_columns
97
- if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
98
- if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
99
- elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
100
- else: raise ValueError(f"{scaling} != 0, 1")
101
-
102
- return torch.diag(multiplier) @ torch.cat(block_list)
103
-
104
- def calc_same_padding(kernel_size):
105
- pad = kernel_size // 2
106
- return (pad, pad - (kernel_size + 1) % 2)
107
-
108
- def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
109
- b, h, *_ = data.shape
110
-
111
- data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
112
- ratio = projection_matrix.shape[0] ** -0.5
113
- data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
114
- diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)
115
-
116
- return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)
117
-
118
- def torch_interp(x, xp, fp):
119
- sort_idx = torch.argsort(xp)
120
- xp = xp[sort_idx]
121
- fp = fp[sort_idx]
122
-
123
- right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
124
- left_idxs = (right_idxs - 1).clamp(min=0)
125
- x_left = xp[left_idxs]
126
- y_left = fp[left_idxs]
127
-
128
- interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left))
129
- interp_vals[x < xp[0]] = fp[0]
130
- interp_vals[x > xp[-1]] = fp[-1]
131
-
132
- return interp_vals
133
-
134
- def batch_interp_with_replacement_detach(uv, f0):
135
- result = f0.clone()
136
- for i in range(uv.shape[0]):
137
- interp_vals = torch_interp(torch.where(uv[i])[-1], torch.where(~uv[i])[-1], f0[i][~uv[i]]).detach()
138
- result[i][uv[i]] = interp_vals
139
-
140
- return result
141
-
142
- def catch_none_args_must(x, func_name, warning_str):
143
- if x is None: raise ValueError(f'[Error] {warning_str}\n[Error] > {func_name}')
144
- else: return x
145
-
146
- def catch_none_args_opti(x, default, func_name, warning_str=None, level='WARN'):
147
- return default if x is None else x
148
-
149
- def spawn_wav2mel(args, device = None):
150
- _type = args.mel.type
151
- if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): _type = 'default'
152
- elif str(_type).lower() == 'stft': _type = 'stft'
153
- wav2mel = Wav2MelModule(sr=catch_none_args_opti(args.mel.sr, default=16000, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.sr is None'), n_mels=catch_none_args_opti(args.mel.num_mels, default=128, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.num_mels is None'), n_fft=catch_none_args_opti(args.mel.n_fft, default=1024, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.n_fft is None'), win_size=catch_none_args_opti(args.mel.win_size, default=1024, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.win_size is None'), hop_length=catch_none_args_opti(args.mel.hop_size, default=160, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.hop_size is None'), fmin=catch_none_args_opti(args.mel.fmin, default=0, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.fmin is None'), fmax=catch_none_args_opti(args.mel.fmax, default=8000, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.fmax is None'), clip_val=1e-05, mel_type=_type)
154
- device = catch_none_args_opti(device, default='cpu', func_name='torchfcpe.tools.spawn_wav2mel', warning_str='.device is None')
155
-
156
- return wav2mel.to(torch.device(device))
157
-
158
- def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
159
- device = f0s.device
160
- f0s = f0s / (torch.pow(2, torch.tensor(key_shift_list, device=device).to(device).unsqueeze(0).unsqueeze(0) / 12))
161
- notes = torch.log2(f0s / 440) * 12 + 69
162
- notes[notes < 0] = 0
163
-
164
- uv_penalty = tta_uv_penalty**2
165
- dp = torch.zeros_like(notes, device=device)
166
- backtrack = torch.zeros_like(notes, device=device).long()
167
- dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
168
-
169
- for t in range(1, notes.size(1)):
170
- penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)
171
- t_uv = notes[:, t, :] <= 0
172
- penalty += uv_penalty * t_uv.unsqueeze(1)
173
-
174
- t1_uv = notes[:, t - 1, :] <= 0
175
- l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
176
- l2 = l2 * (l2 > 0)
177
-
178
- penalty += l2
179
- penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
180
-
181
- min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
182
- dp[:, t, :] = min_value
183
- backtrack[:, t, :] = min_indices
184
-
185
- t = f0s.size(1) - 1
186
- f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
187
- min_indices = torch.argmin(dp[:, t, :], dim=-1)
188
-
189
- for i in range(0, t + 1):
190
- f0_result[:, t - i] = f0s[:, t - i, min_indices]
191
- min_indices = backtrack[:, t - i, min_indices]
192
-
193
- return f0_result.unsqueeze(-1)
194
-
195
- class LocalAttention(nn.Module):
196
- def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
197
- super().__init__()
198
- look_forward = default(look_forward, 0 if causal else 1)
199
- assert not (causal and look_forward > 0)
200
- self.scale = scale
201
- self.window_size = window_size
202
- self.autopad = autopad
203
- self.exact_windowsize = exact_windowsize
204
- self.causal = causal
205
- self.look_backward = look_backward
206
- self.look_forward = look_forward
207
- self.dropout = nn.Dropout(dropout)
208
- self.shared_qk = shared_qk
209
- self.rel_pos = None
210
- self.use_xpos = use_xpos
211
- if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
212
- if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
213
- self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))
214
-
215
- def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
216
- mask = default(mask, input_mask)
217
- assert not (exists(window_size) and not self.use_xpos)
218
-
219
- _, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
220
- (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
221
-
222
- if autopad:
223
- orig_seq_len = q.shape[1]
224
- (_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
225
-
226
- b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
227
- scale = default(self.scale, dim_head ** -0.5)
228
-
229
- assert (n % window_size) == 0
230
- windows = n // window_size
231
-
232
- if shared_qk: k = F.normalize(k, dim = -1).type(k.dtype)
233
-
234
- seq = torch.arange(n, device = device)
235
- b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
236
- bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
237
-
238
- bq = bq * scale
239
- look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)
240
-
241
- bk = look_around(bk, **look_around_kwargs)
242
- bv = look_around(bv, **look_around_kwargs)
243
-
244
- if exists(self.rel_pos):
245
- pos_emb, xpos_scale = self.rel_pos(bk)
246
- bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)
247
-
248
- bq_t = b_t
249
- bq_k = look_around(b_t, **look_around_kwargs)
250
- bq_t = rearrange(bq_t, '... i -> ... i 1')
251
- bq_k = rearrange(bq_k, '... j -> ... 1 j')
252
-
253
- pad_mask = bq_k == pad_value
254
- sim = einsum('b h i e, b h j e -> b h i j', bq, bk)
255
-
256
- if exists(attn_bias):
257
- heads = attn_bias.shape[0]
258
- assert (b % heads) == 0
259
-
260
- attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
261
- sim = sim + attn_bias
262
-
263
- mask_value = -torch.finfo(sim.dtype).max
264
- if shared_qk:
265
- self_mask = bq_t == bq_k
266
- sim = sim.masked_fill(self_mask, -5e4)
267
- del self_mask
268
-
269
- if causal:
270
- causal_mask = bq_t < bq_k
271
- if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
272
- sim = sim.masked_fill(causal_mask, mask_value)
273
- del causal_mask
274
-
275
- sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)
276
-
277
- if exists(mask):
278
- batch = mask.shape[0]
279
- assert (b % batch) == 0
280
-
281
- h = b // mask.shape[0]
282
- if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)
283
-
284
- mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
285
- sim = sim.masked_fill(~mask, mask_value)
286
-
287
- del mask
288
-
289
- out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
290
- if autopad: out = out[:, :orig_seq_len, :]
291
-
292
- out, *_ = unpack(out, packed_shape, '* n d')
293
- return out
294
-
295
- class SinusoidalEmbeddings(nn.Module):
296
- def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
297
- super().__init__()
298
- inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
299
- self.register_buffer('inv_freq', inv_freq)
300
- self.use_xpos = use_xpos
301
- self.scale_base = scale_base
302
- assert not (use_xpos and not exists(scale_base))
303
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
304
- self.register_buffer('scale', scale, persistent = False)
305
-
306
- def forward(self, x):
307
- seq_len, device = x.shape[-2], x.device
308
- t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
309
-
310
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
311
- freqs = torch.cat((freqs, freqs), dim = -1)
312
-
313
- if not self.use_xpos: return freqs, torch.ones(1, device = device)
314
-
315
- power = (t - (seq_len // 2)) / self.scale_base
316
- scale = self.scale ** rearrange(power, 'n -> n 1')
317
-
318
- return freqs, torch.cat((scale, scale), dim = -1)
319
-
320
- class STFT:
321
- def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
322
- self.target_sr = sr
323
- self.n_mels = n_mels
324
- self.n_fft = n_fft
325
- self.win_size = win_size
326
- self.hop_length = hop_length
327
- self.fmin = fmin
328
- self.fmax = fmax
329
- self.clip_val = clip_val
330
- self.mel_basis = {}
331
- self.hann_window = {}
332
-
333
- def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
334
- n_fft = self.n_fft
335
- win_size = self.win_size
336
- hop_length = self.hop_length
337
- fmax = self.fmax
338
- factor = 2 ** (keyshift / 12)
339
- win_size_new = int(np.round(win_size * factor))
340
- hop_length_new = int(np.round(hop_length * speed))
341
- mel_basis = self.mel_basis if not train else {}
342
- hann_window = self.hann_window if not train else {}
343
- mel_basis_key = str(fmax) + "_" + str(y.device)
344
-
345
- if mel_basis_key not in mel_basis: mel_basis[mel_basis_key] = torch.from_numpy(librosa_mel_fn(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
346
- keyshift_key = str(keyshift) + "_" + str(y.device)
347
- if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
348
-
349
- pad_left = (win_size_new - hop_length_new) // 2
350
- pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
351
- spec = torch.stft(F.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1), int(np.round(n_fft * factor)), hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
352
- spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
353
-
354
- if keyshift != 0:
355
- size = n_fft // 2 + 1
356
- resize = spec.size(1)
357
- spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new
358
-
359
- return dynamic_range_compression_torch(torch.matmul(mel_basis[mel_basis_key], spec), clip_val=self.clip_val)
360
-
361
- class PCmer(nn.Module):
362
- def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
363
- super().__init__()
364
- self.num_layers = num_layers
365
- self.num_heads = num_heads
366
- self.dim_model = dim_model
367
- self.dim_values = dim_values
368
- self.dim_keys = dim_keys
369
- self.residual_dropout = residual_dropout
370
- self.attention_dropout = attention_dropout
371
- self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
372
-
373
- def forward(self, phone, mask=None):
374
- for layer in self._layers:
375
- phone = layer(phone, mask)
376
-
377
- return phone
378
-
379
- class _EncoderLayer(nn.Module):
380
- def __init__(self, parent):
381
- super().__init__()
382
- self.conformer = ConformerConvModule_LEGACY(parent.dim_model)
383
- self.norm = nn.LayerNorm(parent.dim_model)
384
- self.dropout = nn.Dropout(parent.residual_dropout)
385
- self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)
386
-
387
- def forward(self, phone, mask=None):
388
- phone = phone + (self.attn(self.norm(phone), mask=mask))
389
- return phone + (self.conformer(phone))
390
-
391
- class ConformerNaiveEncoder(nn.Module):
392
- def __init__(self, num_layers, num_heads, dim_model, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
393
- super().__init__()
394
- self.num_layers = num_layers
395
- self.num_heads = num_heads
396
- self.dim_model = dim_model
397
- self.use_norm = use_norm
398
- self.residual_dropout = 0.1
399
- self.attention_dropout = 0.1
400
- self.encoder_layers = nn.ModuleList([CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) for _ in range(num_layers)])
401
-
402
- def forward(self, x, mask=None):
403
- for (_, layer) in enumerate(self.encoder_layers):
404
- x = layer(x, mask)
405
-
406
- return x
407
-
408
- class CFNaiveMelPE(nn.Module):
409
- def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
410
- super().__init__()
411
- self.input_channels = input_channels
412
- self.out_dims = out_dims
413
- self.hidden_dims = hidden_dims
414
- self.n_layers = n_layers
415
- self.n_heads = n_heads
416
- self.f0_max = f0_max
417
- self.f0_min = f0_min
418
- self.use_fa_norm = use_fa_norm
419
- self.residual_dropout = 0.1
420
- self.attention_dropout = 0.1
421
- self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
422
- self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
423
- self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
424
- self.norm = nn.LayerNorm(hidden_dims)
425
- self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
426
- self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
427
- self.register_buffer("cent_table", self.cent_table_b)
428
- self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
429
- self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)
430
-
431
- def forward(self, x, _h_emb=None):
432
- x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
433
- if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))
434
- return torch.sigmoid(self.output_proj(self.norm(self.net(x))))
435
-
436
- @torch.no_grad()
437
- def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
438
- B, N, _ = y.size()
439
- ci = self.cent_table[None, None, :].expand(B, N, -1)
440
- rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
441
-
442
- if mask:
443
- confident = torch.max(y, dim=-1, keepdim=True)[0]
444
- confident_mask = torch.ones_like(confident)
445
- confident_mask[confident <= threshold] = float("-INF")
446
- rtn = rtn * confident_mask
447
-
448
- return rtn
449
-
450
- @torch.no_grad()
451
- def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
452
- B, N, _ = y.size()
453
- ci = self.cent_table[None, None, :].expand(B, N, -1)
454
- confident, max_index = torch.max(y, dim=-1, keepdim=True)
455
-
456
- local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
457
- local_argmax_index[local_argmax_index < 0] = 0
458
- local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1
459
-
460
- y_l = torch.gather(y, -1, local_argmax_index)
461
- rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
462
-
463
- if mask:
464
- confident_mask = torch.ones_like(confident)
465
- confident_mask[confident <= threshold] = float("-INF")
466
- rtn = rtn * confident_mask
467
-
468
- return rtn
469
-
470
- @torch.no_grad()
471
- def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
472
- latent = self.forward(mel)
473
- if decoder == "argmax": cents = self.latent2cents_local_decoder
474
- elif decoder == "local_argmax": cents = self.latent2cents_local_decoder
475
-
476
- return self.cent_to_f0(cents(latent, threshold=threshold))
477
-
478
- @torch.no_grad()
479
- def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
480
- return 10 * 2 ** (cent / 1200)
481
-
482
- @torch.no_grad()
483
- def f0_to_cent(self, f0):
484
- return 1200 * torch.log2(f0 / 10)
485
-
486
- class CFNEncoderLayer(nn.Module):
487
- def __init__(self, dim_model, num_heads = 8, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
488
- super().__init__()
489
- self.conformer = nn.Sequential(ConformerConvModule(dim_model), nn.Dropout(conv_dropout)) if conv_dropout > 0 else ConformerConvModule(dim_model)
490
- self.norm = nn.LayerNorm(dim_model)
491
- self.dropout = nn.Dropout(0.1)
492
- self.attn = SelfAttention(dim=dim_model, heads=num_heads, causal=False, use_norm=use_norm, dropout=atten_dropout) if not conv_only else None
493
-
494
- def forward(self, x, mask=None):
495
- if self.attn is not None: x = x + (self.attn(self.norm(x), mask=mask))
496
- return x + (self.conformer(x))
497
-
498
- class Swish(nn.Module):
499
- def forward(self, x):
500
- return x * x.sigmoid()
501
-
502
- class Transpose(nn.Module):
503
- def __init__(self, dims):
504
- super().__init__()
505
- assert len(dims) == 2, "dims == 2"
506
- self.dims = dims
507
-
508
- def forward(self, x):
509
- return x.transpose(*self.dims)
510
-
511
- class GLU(nn.Module):
512
- def __init__(self, dim):
513
- super().__init__()
514
- self.dim = dim
515
-
516
- def forward(self, x):
517
- out, gate = x.chunk(2, dim=self.dim)
518
- return out * gate.sigmoid()
519
-
520
- class DepthWiseConv1d_LEGACY(nn.Module):
521
- def __init__(self, chan_in, chan_out, kernel_size, padding):
522
- super().__init__()
523
- self.padding = padding
524
- self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
525
-
526
- def forward(self, x):
527
- return self.conv(F.pad(x, self.padding))
528
-
529
- class DepthWiseConv1d(nn.Module):
530
- def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
531
- super().__init__()
532
- self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)
533
-
534
- def forward(self, x):
535
- return self.conv(x)
536
-
537
- class ConformerConvModule_LEGACY(nn.Module):
538
- def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
539
- super().__init__()
540
- inner_dim = dim * expansion_factor
541
- self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d_LEGACY(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
542
-
543
- def forward(self, x):
544
- return self.net(x)
545
-
546
- class ConformerConvModule(nn.Module):
547
- def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0):
548
- super().__init__()
549
- inner_dim = dim * expansion_factor
550
- self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), nn.GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=calc_same_padding(kernel_size)[0], groups=inner_dim), nn.SiLU(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
551
-
552
- def forward(self, x):
553
- return self.net(x)
554
-
555
- class FastAttention(nn.Module):
556
- def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
557
- super().__init__()
558
- nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
559
- self.dim_heads = dim_heads
560
- self.nb_features = nb_features
561
- self.ortho_scaling = ortho_scaling
562
- self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
563
- projection_matrix = self.create_projection()
564
- self.register_buffer("projection_matrix", projection_matrix)
565
- self.generalized_attention = generalized_attention
566
- self.kernel_fn = kernel_fn
567
- self.no_projection = no_projection
568
- self.causal = causal
569
-
570
- @torch.no_grad()
571
- def redraw_projection_matrix(self):
572
- projections = self.create_projection()
573
- self.projection_matrix.copy_(projections)
574
- del projections
575
-
576
- def forward(self, q, k, v):
577
- if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
578
- else:
579
- create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
580
- q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)
581
-
582
- attn_fn = linear_attention if not self.causal else self.causal_linear_fn
583
- return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)
584
-
585
- class SelfAttention(nn.Module):
586
- def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
587
- super().__init__()
588
- assert dim % heads == 0
589
- dim_head = default(dim_head, dim // heads)
590
- inner_dim = dim_head * heads
591
- self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
592
- self.heads = heads
593
- self.global_heads = heads - local_heads
594
- self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
595
- self.to_q = nn.Linear(dim, inner_dim)
596
- self.to_k = nn.Linear(dim, inner_dim)
597
- self.to_v = nn.Linear(dim, inner_dim)
598
- self.to_out = nn.Linear(inner_dim, dim)
599
- self.dropout = nn.Dropout(dropout)
600
-
601
- @torch.no_grad()
602
- def redraw_projection_matrix(self):
603
- self.fast_attention.redraw_projection_matrix()
604
-
605
- def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
606
- _, _, _, h, gh = *x.shape, self.heads, self.global_heads
607
- cross_attend = exists(context)
608
- context = default(context, x)
609
- context_mask = default(context_mask, mask) if not cross_attend else context_mask
610
-
611
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
612
- (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
613
-
614
- attn_outs = []
615
-
616
- if not empty(q):
617
- if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
618
- if cross_attend: pass
619
- else: out = self.fast_attention(q, k, v)
620
-
621
- attn_outs.append(out)
622
-
623
- if not empty(lq):
624
- assert (not cross_attend), "not cross_attend"
625
-
626
- out = self.local_attn(lq, lk, lv, input_mask=mask)
627
- attn_outs.append(out)
628
-
629
- return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))
630
-
631
- class HannWindow(torch.nn.Module):
632
- def __init__(self, win_size):
633
- super().__init__()
634
- self.register_buffer('window', torch.hann_window(win_size), persistent=False)
635
-
636
- def forward(self):
637
- return self.window
638
-
639
- class FCPE_LEGACY(nn.Module):
640
- def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
641
- super().__init__()
642
- self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
643
- self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
644
- self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
645
- self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
646
- self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
647
- self.f0_max = f0_max if (f0_max is not None) else 1975.5
648
- self.f0_min = f0_min if (f0_min is not None) else 32.70
649
- self.confidence = confidence if (confidence is not None) else False
650
- self.threshold = threshold if (threshold is not None) else 0.05
651
- self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
652
- self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
653
- self.register_buffer("cent_table", self.cent_table_b)
654
- self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
655
- self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
656
- self.norm = nn.LayerNorm(n_chans)
657
- self.n_out = out_dims
658
- self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
659
-
660
- def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
661
- if cdecoder == "argmax": self.cdecoder = self.cents_decoder
662
- elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
663
-
664
- x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))
665
-
666
- if not infer:
667
- loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
668
- if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
669
- x = loss_all
670
- else:
671
- x = self.cent_to_f0(self.cdecoder(x))
672
- x = (1 + x / 700).log() if not return_hz_f0 else x
673
-
674
- if output_interp_target_length is not None:
675
- x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
676
- x = torch.where(x.isnan(), float(0.0), x)
677
-
678
- return x
679
-
680
- def cents_decoder(self, y, mask=True):
681
- B, N, _ = y.size()
682
- rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
683
-
684
- if mask:
685
- confident = torch.max(y, dim=-1, keepdim=True)[0]
686
- confident_mask = torch.ones_like(confident)
687
- confident_mask[confident <= self.threshold] = float("-INF")
688
- rtn = rtn * confident_mask
689
-
690
- return (rtn, confident) if self.confidence else rtn
691
-
692
- def cents_local_decoder(self, y, mask=True):
693
- B, N, _ = y.size()
694
-
695
- confident, max_index = torch.max(y, dim=-1, keepdim=True)
696
- local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
697
- y_l = torch.gather(y, -1, local_argmax_index)
698
- rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
699
-
700
- if mask:
701
- confident_mask = torch.ones_like(confident)
702
- confident_mask[confident <= self.threshold] = float("-INF")
703
- rtn = rtn * confident_mask
704
-
705
- return (rtn, confident) if self.confidence else rtn
706
-
707
- def cent_to_f0(self, cent):
708
- return 10.0 * 2 ** (cent / 1200.0)
709
-
710
- def f0_to_cent(self, f0):
711
- return 1200.0 * torch.log2(f0 / 10.0)
712
-
713
- def gaussian_blurred_cent(self, cents):
714
- B, N, _ = cents.size()
715
- return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0))).float()
716
-
717
- class InferCFNaiveMelPE(torch.nn.Module):
718
- def __init__(self, args, state_dict):
719
- super().__init__()
720
- self.wav2mel = spawn_wav2mel(args, device="cpu")
721
- self.model = CFNaiveMelPE(input_channels=catch_none_args_must(args.mel.num_mels, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.mel.num_mels is None"), out_dims=catch_none_args_must(args.model.out_dims, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.out_dims is None"), hidden_dims=catch_none_args_must(args.model.hidden_dims, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.hidden_dims is None"), n_layers=catch_none_args_must(args.model.n_layers, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.n_layers is None"), n_heads=catch_none_args_must(args.model.n_heads, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.n_heads is None"), f0_max=catch_none_args_must(args.model.f0_max, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.f0_max is None"), f0_min=catch_none_args_must(args.model.f0_min, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.f0_min is None"), use_fa_norm=catch_none_args_must(args.model.use_fa_norm, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.use_fa_norm is None"), conv_only=catch_none_args_opti(args.model.conv_only, default=False, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.conv_only is None"), conv_dropout=catch_none_args_opti(args.model.conv_dropout, default=0.0, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.conv_dropout is None"), atten_dropout=catch_none_args_opti(args.model.atten_dropout, default=0.0, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.atten_dropout is None"), use_harmonic_emb=catch_none_args_opti(args.model.use_harmonic_emb, default=False, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.use_harmonic_emb is None"))
722
- self.model.load_state_dict(state_dict)
723
- self.model.eval()
724
- self.args_dict = dict(args)
725
- self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)
726
-
727
- def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
728
- with torch.no_grad():
729
- mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
730
- f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))
731
-
732
- return f0s
733
-
734
- def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
735
- if test_time_augmentation:
736
- assert len(tta_key_shifts) > 0
737
- flag = 0
738
- if tta_use_origin_uv:
739
- if 0 not in tta_key_shifts:
740
- flag = 1
741
- tta_key_shifts.append(0)
742
-
743
- tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
744
- f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
745
- f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)
746
- f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
747
- else:
748
- f0 = self.__call__(wav, sr, decoder_mode, threshold)
749
- f0_for_uv = f0
750
-
751
- if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]
752
- uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
753
- f0 = f0 * (1 - uv)
754
-
755
- if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
756
- if f0_max is not None: f0[f0 > f0_max] = f0_max
757
- if output_interp_target_length is not None:
758
- f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
759
- f0 = torch.where(f0.isnan(), float(0.0), f0)
760
-
761
- if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
762
- else: return f0
763
-
764
- class FCPEInfer_LEGACY:
765
- def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
766
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
767
- self.device = device
768
- self.dtype = dtype
769
- self.onnx = onnx
770
- self.f0_min = f0_min
771
- self.f0_max = f0_max
772
-
773
- if self.onnx:
774
- sess_options = ort.SessionOptions()
775
- sess_options.log_severity_level = 3
776
- self.model = ort.InferenceSession(decrypt_model(model_path), sess_options=sess_options, providers=providers)
777
- else:
778
- ckpt = torch.load(model_path, map_location=torch.device(self.device))
779
- self.args = DotDict(ckpt["config"])
780
- model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence)
781
- model.to(self.device).to(self.dtype)
782
- model.load_state_dict(ckpt["model"])
783
- model.eval()
784
- self.model = model
785
-
786
- @torch.no_grad()
787
- def __call__(self, audio, sr, threshold=0.05, p_len=None):
788
- if not self.onnx: self.model.threshold = threshold
789
- self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
790
-
791
- return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True, output_interp_target_length=p_len))
792
-
793
- class FCPEInfer:
794
- def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
795
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
796
- self.device = device
797
- self.dtype = dtype
798
- self.onnx = onnx
799
- self.f0_min = f0_min
800
- self.f0_max = f0_max
801
-
802
- if self.onnx:
803
- sess_options = ort.SessionOptions()
804
- sess_options.log_severity_level = 3
805
- self.model = ort.InferenceSession(decrypt_model(model_path), sess_options=sess_options, providers=providers)
806
- else:
807
- ckpt = torch.load(model_path, map_location=torch.device(device))
808
- ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
809
- self.args = DotDict(ckpt["config_dict"])
810
- model = InferCFNaiveMelPE(self.args, ckpt["model"])
811
- model = model.to(device).to(self.dtype)
812
- model.eval()
813
- self.model = model
814
-
815
- @torch.no_grad()
816
- def __call__(self, audio, sr, threshold=0.05, p_len=None):
817
- if self.onnx: self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
818
- return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len))
819
-
820
- class MelModule(torch.nn.Module):
821
- def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, out_stft = False):
822
- super().__init__()
823
- if fmin is None: fmin = 0
824
- if fmax is None: fmax = sr / 2
825
- self.target_sr = sr
826
- self.n_mels = n_mels
827
- self.n_fft = n_fft
828
- self.win_size = win_size
829
- self.hop_length = hop_length
830
- self.fmin = fmin
831
- self.fmax = fmax
832
- self.clip_val = clip_val
833
- self.register_buffer('mel_basis', torch.tensor(librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
834
- self.hann_window = torch.nn.ModuleDict()
835
- self.out_stft = out_stft
836
-
837
- @torch.no_grad()
838
- def __call__(self, y, key_shift = 0, speed = 1, center = False, no_cache_window = False):
839
- n_fft = self.n_fft
840
- win_size = self.win_size
841
- hop_length = self.hop_length
842
- clip_val = self.clip_val
843
- factor = 2 ** (key_shift / 12)
844
- n_fft_new = int(np.round(n_fft * factor))
845
- win_size_new = int(np.round(win_size * factor))
846
- hop_length_new = int(np.round(hop_length * speed))
847
-
848
- y = y.squeeze(-1)
849
- if torch.min(y) < -1: print('[error with torchfcpe.mel_extractor.MelModule] min ', torch.min(y))
850
- if torch.max(y) > 1: print('[error with torchfcpe.mel_extractor.MelModule] max ', torch.max(y))
851
-
852
- key_shift_key = str(key_shift)
853
- if not no_cache_window:
854
- if key_shift_key in self.hann_window: hann_window = self.hann_window[key_shift_key]
855
- else:
856
- hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
857
- self.hann_window[key_shift_key] = hann_window
858
-
859
- hann_window_tensor = hann_window()
860
- else: hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)
861
-
862
- pad_left = (win_size_new - hop_length_new) // 2
863
- pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
864
- mode = 'reflect' if pad_right < y.size(-1) else 'constant'
865
- spec = torch.stft(F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1), n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
866
- spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
867
-
868
- if key_shift != 0:
869
- size = n_fft // 2 + 1
870
- resize = spec.size(1)
871
-
872
- if resize < size: spec = F.pad(spec, (0, 0, 0, size - resize))
873
- spec = spec[:, :size, :] * win_size / win_size_new
874
-
875
- spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
876
- return dynamic_range_compression_torch(spec, clip_val=clip_val).transpose(-1, -2)
877
-
878
- class Wav2MelModule(torch.nn.Module):
879
- def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, mel_type="default"):
880
- super().__init__()
881
- if fmin is None: fmin = 0
882
- if fmax is None: fmax = sr / 2
883
- self.sampling_rate = sr
884
- self.n_mels = n_mels
885
- self.n_fft = n_fft
886
- self.win_size = win_size
887
- self.hop_size = hop_length
888
- self.fmin = fmin
889
- self.fmax = fmax
890
- self.clip_val = clip_val
891
- self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
892
- self.resample_kernel = torch.nn.ModuleDict()
893
- if mel_type == "default": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
894
- elif mel_type == "stft": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
895
- self.mel_type = mel_type
896
-
897
- @torch.no_grad()
898
- def __call__(self, audio, sample_rate, keyshift = 0, no_cache_window = False):
899
- if sample_rate == self.sampling_rate: audio_res = audio
900
- else:
901
- key_str = str(sample_rate)
902
- if key_str not in self.resample_kernel:
903
- if len(self.resample_kernel) > 8: self.resample_kernel.clear()
904
- self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)
905
-
906
- audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
907
-
908
- mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
909
- n_frames = int(audio.shape[1] // self.hop_size) + 1
910
- if n_frames > int(mel.shape[1]): mel = torch.cat((mel, mel[:, -1:, :]), 1)
911
- if n_frames < int(mel.shape[1]): mel = mel[:, :n_frames, :]
912
-
913
- return mel
914
-
915
- class Wav2Mel:
916
- def __init__(self, device=None, dtype=torch.float32):
917
- self.sample_rate = 16000
918
- self.hop_size = 160
919
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
920
- self.device = device
921
- self.dtype = dtype
922
- self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
923
- self.resample_kernel = {}
924
-
925
- def extract_nvstft(self, audio, keyshift=0, train=False):
926
- return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
927
-
928
- def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
929
- audio = audio.to(self.dtype).to(self.device)
930
- if sample_rate == self.sample_rate: audio_res = audio
931
- else:
932
- key_str = str(sample_rate)
933
- if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
934
- self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
935
- audio_res = self.resample_kernel[key_str](audio)
936
-
937
- mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
938
- n_frames = int(audio.shape[1] // self.hop_size) + 1
939
- mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
940
- return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
941
-
942
- def __call__(self, audio, sample_rate, keyshift=0, train=False):
943
- return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
944
-
945
- class DotDict(dict):
946
- def __getattr__(*args):
947
- val = dict.get(*args)
948
- return DotDict(val) if type(val) is dict else val
949
-
950
- __setattr__ = dict.__setitem__
951
- __delattr__ = dict.__delitem__
952
-
953
- class FCPE:
954
- def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, providers=None, onnx=False, legacy=False):
955
- self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
956
- self.fcpe = self.model(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx, f0_min=f0_min, f0_max=f0_max)
957
- self.hop_length = hop_length
958
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
959
- self.threshold = threshold
960
- self.sample_rate = sample_rate
961
- self.dtype = dtype
962
- self.legacy = legacy
963
-
964
- def repeat_expand(self, content, target_len, mode = "nearest"):
965
- ndim = content.ndim
966
- content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)
967
-
968
- assert content.ndim == 3
969
- is_np = isinstance(content, np.ndarray)
970
-
971
- results = F.interpolate(torch.from_numpy(content) if is_np else content, size=target_len, mode=mode)
972
- results = results.numpy() if is_np else results
973
- return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results
974
-
975
- def post_process(self, x, sample_rate, f0, pad_to):
976
- f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
977
- f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0
978
-
979
- vuv_vector = torch.zeros_like(f0)
980
- vuv_vector[f0 > 0.0] = 1.0
981
- vuv_vector[f0 <= 0.0] = 0.0
982
-
983
- nzindex = torch.nonzero(f0).squeeze()
984
- f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
985
- vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
986
-
987
- if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
988
- if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()
989
-
990
- return np.interp(np.arange(pad_to) * self.hop_length / sample_rate, self.hop_length / sample_rate * nzindex.cpu().numpy(), f0, left=f0[0], right=f0[-1]), vuv_vector.cpu().numpy()
991
-
992
- def compute_f0(self, wav, p_len=None):
993
- x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
994
- p_len = x.shape[0] // self.hop_length if p_len is None else p_len
995
-
996
- f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
997
- f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
998
-
999
- if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
1000
- return self.post_process(x, self.sample_rate, f0, p_len)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/predictors/RMVPE.py DELETED
@@ -1,260 +0,0 @@
1
- import torch
2
-
3
- import numpy as np
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- from librosa.filters import mel
8
-
9
- N_MELS, N_CLASS = 128, 360
10
-
11
- class ConvBlockRes(nn.Module):
12
- def __init__(self, in_channels, out_channels, momentum=0.01):
13
- super(ConvBlockRes, self).__init__()
14
- self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
15
-
16
- if in_channels != out_channels:
17
- self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
18
- self.is_shortcut = True
19
- else: self.is_shortcut = False
20
-
21
- def forward(self, x):
22
- return self.conv(x) + self.shortcut(x) if self.is_shortcut else self.conv(x) + x
23
-
24
- class ResEncoderBlock(nn.Module):
25
- def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
26
- super(ResEncoderBlock, self).__init__()
27
- self.n_blocks = n_blocks
28
- self.conv = nn.ModuleList()
29
- self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
30
-
31
- for _ in range(n_blocks - 1):
32
- self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
33
-
34
- self.kernel_size = kernel_size
35
- if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
36
-
37
- def forward(self, x):
38
- for i in range(self.n_blocks):
39
- x = self.conv[i](x)
40
-
41
- if self.kernel_size is not None: return x, self.pool(x)
42
- else: return x
43
-
44
- class Encoder(nn.Module):
45
- def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
46
- super(Encoder, self).__init__()
47
- self.n_encoders = n_encoders
48
- self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
49
- self.layers = nn.ModuleList()
50
- self.latent_channels = []
51
-
52
- for _ in range(self.n_encoders):
53
- self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
54
- self.latent_channels.append([out_channels, in_size])
55
- in_channels = out_channels
56
- out_channels *= 2
57
- in_size //= 2
58
-
59
- self.out_size = in_size
60
- self.out_channel = out_channels
61
-
62
- def forward(self, x):
63
- concat_tensors = []
64
- x = self.bn(x)
65
-
66
- for i in range(self.n_encoders):
67
- t, x = self.layers[i](x)
68
- concat_tensors.append(t)
69
-
70
- return x, concat_tensors
71
-
72
- class Intermediate(nn.Module):
73
- def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
74
- super(Intermediate, self).__init__()
75
- self.n_inters = n_inters
76
- self.layers = nn.ModuleList()
77
- self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
78
-
79
- for _ in range(self.n_inters - 1):
80
- self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
81
-
82
- def forward(self, x):
83
- for i in range(self.n_inters):
84
- x = self.layers[i](x)
85
-
86
- return x
87
-
88
- class ResDecoderBlock(nn.Module):
89
- def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
90
- super(ResDecoderBlock, self).__init__()
91
- out_padding = (0, 1) if stride == (1, 2) else (1, 1)
92
- self.n_blocks = n_blocks
93
- self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
94
- self.conv2 = nn.ModuleList()
95
- self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
96
-
97
- for _ in range(n_blocks - 1):
98
- self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
99
-
100
- def forward(self, x, concat_tensor):
101
- x = torch.cat((self.conv1(x), concat_tensor), dim=1)
102
-
103
- for i in range(self.n_blocks):
104
- x = self.conv2[i](x)
105
-
106
- return x
107
-
108
- class Decoder(nn.Module):
109
- def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
110
- super(Decoder, self).__init__()
111
- self.layers = nn.ModuleList()
112
- self.n_decoders = n_decoders
113
-
114
- for _ in range(self.n_decoders):
115
- out_channels = in_channels // 2
116
- self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
117
- in_channels = out_channels
118
-
119
- def forward(self, x, concat_tensors):
120
- for i in range(self.n_decoders):
121
- x = self.layers[i](x, concat_tensors[-1 - i])
122
-
123
- return x
124
-
125
- class DeepUnet(nn.Module):
126
- def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
127
- super(DeepUnet, self).__init__()
128
- self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
129
- self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
130
- self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
131
-
132
- def forward(self, x):
133
- x, concat_tensors = self.encoder(x)
134
- return self.decoder(self.intermediate(x), concat_tensors)
135
-
136
- class E2E(nn.Module):
137
- def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
138
- super(E2E, self).__init__()
139
- self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
140
- self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
141
- self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
142
-
143
- def forward(self, mel):
144
- return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
145
-
146
- class MelSpectrogram(torch.nn.Module):
147
- def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
148
- super().__init__()
149
- n_fft = win_length if n_fft is None else n_fft
150
- self.hann_window = {}
151
- mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
152
- mel_basis = torch.from_numpy(mel_basis).float()
153
- self.register_buffer("mel_basis", mel_basis)
154
- self.n_fft = win_length if n_fft is None else n_fft
155
- self.hop_length = hop_length
156
- self.win_length = win_length
157
- self.sample_rate = sample_rate
158
- self.n_mel_channels = n_mel_channels
159
- self.clamp = clamp
160
- self.is_half = is_half
161
-
162
- def forward(self, audio, keyshift=0, speed=1, center=True):
163
- factor = 2 ** (keyshift / 12)
164
- win_length_new = int(np.round(self.win_length * factor))
165
- keyshift_key = str(keyshift) + "_" + str(audio.device)
166
- if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
167
-
168
- fft = torch.stft(audio, n_fft=int(np.round(self.n_fft * factor)), hop_length=int(np.round(self.hop_length * speed)), win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
169
- magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
170
-
171
- if keyshift != 0:
172
- size = self.n_fft // 2 + 1
173
- resize = magnitude.size(1)
174
-
175
- if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
176
- magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
177
-
178
- mel_output = torch.matmul(self.mel_basis, magnitude)
179
- if self.is_half: mel_output = mel_output.half()
180
-
181
- return torch.log(torch.clamp(mel_output, min=self.clamp))
182
-
183
- class RMVPE:
184
- def __init__(self, model_path, is_half, device=None, providers=None, onnx=False):
185
- self.resample_kernel = {}
186
- self.onnx = onnx
187
-
188
- if self.onnx:
189
- import onnxruntime as ort
190
-
191
- sess_options = ort.SessionOptions()
192
- sess_options.log_severity_level = 3
193
-
194
- self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
195
- else:
196
- model = E2E(4, 1, (2, 2))
197
- ckpt = torch.load(model_path, map_location="cpu")
198
- model.load_state_dict(ckpt)
199
- model.eval()
200
- if is_half: model = model.half()
201
- self.model = model.to(device)
202
-
203
- self.resample_kernel = {}
204
- self.is_half = is_half
205
- self.device = device
206
- self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
207
- cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
208
- self.cents_mapping = np.pad(cents_mapping, (4, 4))
209
-
210
- def mel2hidden(self, mel):
211
- with torch.no_grad():
212
- n_frames = mel.shape[-1]
213
- mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
214
- hidden = self.model.run([self.model.get_outputs()[0].name], input_feed={self.model.get_inputs()[0].name: mel.cpu().numpy().astype(np.float32)})[0] if self.onnx else self.model(mel.half() if self.is_half else mel.float())
215
- return hidden[:, :n_frames]
216
-
217
- def decode(self, hidden, thred=0.03):
218
- f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
219
- f0[f0 == 10] = 0
220
-
221
- return f0
222
-
223
- def infer_from_audio(self, audio, thred=0.03):
224
- hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
225
-
226
- return self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()) if not self.onnx else hidden[0], thred=thred)
227
-
228
- def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
229
- hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
230
-
231
- f0 = self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()) if not self.onnx else hidden[0], thred=thred)
232
- f0[(f0 < f0_min) | (f0 > f0_max)] = 0
233
-
234
- return f0
235
-
236
- def to_local_average_cents(self, salience, thred=0.05):
237
- center = np.argmax(salience, axis=1)
238
- salience = np.pad(salience, ((0, 0), (4, 4)))
239
- center += 4
240
- todo_salience, todo_cents_mapping = [], []
241
- starts = center - 4
242
- ends = center + 5
243
-
244
- for idx in range(salience.shape[0]):
245
- todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
246
- todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
247
-
248
- todo_salience = np.array(todo_salience)
249
- devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
250
- devided[np.max(salience, axis=1) <= thred] = 0
251
-
252
- return devided
253
-
254
- class BiGRU(nn.Module):
255
- def __init__(self, input_features, hidden_features, num_layers):
256
- super(BiGRU, self).__init__()
257
- self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
258
-
259
- def forward(self, x):
260
- return self.gru(x)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/predictors/SWIPE.py DELETED
@@ -1,140 +0,0 @@
1
- import math
2
-
3
- import numpy as np
4
-
5
- from matplotlib import mlab
6
- from scipy import interpolate
7
- from decimal import Decimal, ROUND_HALF_UP
8
-
9
- def swipe(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10, sTHR=0.3):
10
- plim = np.array([f0_floor, f0_ceil])
11
- t = np.arange(0, int(1000 * len(x) / fs / (frame_period) + 1)) * (frame_period / 1000)
12
-
13
- log2pc = np.arange(np.log2(plim[0]) * 96, np.log2(plim[-1]) * 96)
14
- log2pc *= (1 / 96)
15
-
16
- pc = 2 ** log2pc
17
- S = np.zeros((len(pc), len(t)))
18
-
19
- logWs = [round_matlab(elm) for elm in np.log2(4 * 2 * fs / plim)]
20
- ws = 2 ** np.arange(logWs[0], logWs[1] - 1, -1)
21
- p0 = 4 * 2 * fs / ws
22
-
23
- d = 1 + log2pc - np.log2(4 * 2 * fs / ws[0])
24
- fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(fs / 2), 0.1))
25
-
26
- for i in range(len(ws)):
27
- dn = round_matlab(4 * fs / p0[i])
28
- X, f, ti = mlab.specgram(x=np.r_[np.zeros(int(ws[i] / 2)), np.r_[x, np.zeros(int(dn + ws[i] / 2))]], NFFT=ws[i], Fs=fs, window=np.hanning(ws[i] + 2)[1:-1], noverlap=max(0, np.round(ws[i] - dn)), mode='complex')
29
- ti = np.r_[0, ti[:-1]]
30
- M = np.maximum(0, interpolate.interp1d(f, np.abs(X.T), kind='cubic')(fERBs)).T
31
-
32
- if i == len(ws) - 1:
33
- j = np.where(d - (i + 1) > -1)[0]
34
- k = np.where(d[j] - (i + 1) < 0)[0]
35
- elif i == 0:
36
- j = np.where(d - (i + 1) < 1)[0]
37
- k = np.where(d[j] - (i + 1) > 0)[0]
38
- else:
39
- j = np.where(np.abs(d - (i + 1)) < 1)[0]
40
- k = np.arange(len(j))
41
-
42
- Si = pitchStrengthAllCandidates(fERBs, np.sqrt(M), pc[j])
43
- Si = interpolate.interp1d(ti, Si, bounds_error=False, fill_value='nan')(t) if Si.shape[1] > 1 else np.full((len(Si), len(t)), np.nan)
44
-
45
- mu = np.ones(j.shape)
46
- mu[k] = 1 - np.abs(d[j[k]] - i - 1)
47
- S[j, :] = S[j, :] + np.tile(mu.reshape(-1, 1), (1, Si.shape[1])) * Si
48
-
49
-
50
- p = np.full((S.shape[1], 1), np.nan)
51
- s = np.full((S.shape[1], 1), np.nan)
52
-
53
- for j in range(S.shape[1]):
54
- s[j] = np.max(S[:, j])
55
- i = np.argmax(S[:, j])
56
-
57
- if s[j] < sTHR: continue
58
-
59
- if i == 0: p[j] = pc[0]
60
- elif i == len(pc) - 1: p[j] = pc[0]
61
- else:
62
- I = np.arange(i-1, i+2)
63
- tc = 1 / pc[I]
64
-
65
- ntc = (tc / tc[1] - 1) * 2 * np.pi
66
- idx = np.isfinite(S[I, j])
67
-
68
- c = np.zeros(len(ntc))
69
- c += np.nan
70
-
71
- I_ = I[idx]
72
-
73
- if len(I_) < 2: c[idx] = (S[I, j])[0] / ntc[0]
74
- else: c[idx] = np.polyfit(ntc[idx], (S[I_, j]), 2)
75
-
76
- pval = np.polyval(c, ((1 / (2 ** np.arange(np.log2(pc[I[0]]), np.log2(pc[I[2]]) + 1 / 12 / 64, 1 / 12 / 64))) / tc[1] - 1) * 2 * np.pi)
77
- s[j] = np.max(pval)
78
- p[j] = 2 ** (np.log2(pc[I[0]]) + (np.argmax(pval)) / 12 / 64)
79
-
80
- p = p.flatten()
81
- p[np.isnan(p)] = 0
82
-
83
- return np.array(p, dtype=np.float32), np.array(t, dtype=np.float32)
84
-
85
- def round_matlab(n):
86
- return int(Decimal(n).quantize(0, ROUND_HALF_UP))
87
-
88
- def pitchStrengthAllCandidates(f, L, pc):
89
- den = np.sqrt(np.sum(L * L, axis=0))
90
- den = np.where(den == 0, 2.220446049250313e-16, den)
91
-
92
- L = L / den
93
- S = np.zeros((len(pc), L.shape[1]))
94
-
95
- for j in range(len(pc)):
96
- S[j,:] = pitchStrengthOneCandidate(f, L, pc[j])
97
-
98
- return S
99
-
100
- def pitchStrengthOneCandidate(f, L, pc):
101
- k = np.zeros(len(f))
102
- q = f / pc
103
-
104
- for i in ([1] + sieve(int(np.fix(f[-1] / pc - 0.75)))):
105
- a = np.abs(q - i)
106
- p = a < 0.25
107
- k[p] = np.cos(2 * np.pi * q[p])
108
-
109
- v = np.logical_and((0.25 < a), (a < 0.75))
110
- k[v] = k[v] + np.cos(2 * np.pi * q[v]) / 2
111
-
112
- k *= np.sqrt(1 / f)
113
- k /= np.linalg.norm(k[k>0])
114
-
115
- return k @ L
116
-
117
- def hz2erbs(hz):
118
- return 21.4 * np.log10(1 + hz / 229)
119
-
120
- def erbs2hz(erbs):
121
- return (10 ** (erbs / 21.4) - 1) * 229
122
-
123
- def sieve(n):
124
- primes = list(range(2, n + 1))
125
- num = 2
126
-
127
- while num < math.sqrt(n):
128
- i = num
129
-
130
- while i <= n:
131
- i += num
132
-
133
- if i in primes: primes.remove(i)
134
-
135
- for j in primes:
136
- if j > num:
137
- num = j
138
- break
139
-
140
- return primes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/predictors/WORLD_WRAPPER.py DELETED
@@ -1,90 +0,0 @@
1
- import os
2
- import torch
3
- import ctypes
4
- import platform
5
-
6
- import numpy as np
7
-
8
-
9
-
10
- class DioOption(ctypes.Structure):
11
- _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]
12
-
13
- class HarvestOption(ctypes.Structure):
14
- _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]
15
-
16
- class PYWORLD:
17
- def __init__(self):
18
- self.world_path = os.path.join("assets", "models", "predictors", "world")
19
- os.makedirs(self.world_path, exist_ok=True)
20
-
21
- model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")
22
- self.world_file_path = os.path.join(self.world_path, f"{model_type}{suffix}")
23
-
24
- if not os.path.exists(self.world_file_path):
25
- model = torch.load(os.path.join("assets", "models", "predictors", "world.pth"), map_location="cpu")
26
-
27
- with open(self.world_file_path, "wb") as w:
28
- w.write(model[model_type])
29
-
30
- self.world_dll = ctypes.CDLL(self.world_file_path)
31
-
32
- def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
33
- self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
34
- self.world_dll.Harvest.restype = None
35
-
36
- self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
37
- self.world_dll.InitializeHarvestOption.restype = None
38
-
39
- self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
40
- self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int
41
-
42
- option = HarvestOption()
43
- self.world_dll.InitializeHarvestOption(ctypes.byref(option))
44
-
45
- option.F0Floor = f0_floor
46
- option.F0Ceil = f0_ceil
47
- option.FramePeriod = frame_period
48
-
49
- f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
50
- f0 = (ctypes.c_double * f0_length)()
51
- tpos = (ctypes.c_double * f0_length)()
52
-
53
- self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
54
- return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
55
-
56
- def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
57
- self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
58
- self.world_dll.Dio.restype = None
59
-
60
- self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
61
- self.world_dll.InitializeDioOption.restype = None
62
-
63
- self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
64
- self.world_dll.GetSamplesForDIO.restype = ctypes.c_int
65
-
66
- option = DioOption()
67
- self.world_dll.InitializeDioOption(ctypes.byref(option))
68
-
69
- option.F0Floor = f0_floor
70
- option.F0Ceil = f0_ceil
71
- option.ChannelsInOctave = channels_in_octave
72
- option.FramePeriod = frame_period
73
- option.Speed = speed
74
- option.AllowedRange = allowed_range
75
-
76
- f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
77
- f0 = (ctypes.c_double * f0_length)()
78
- tpos = (ctypes.c_double * f0_length)()
79
-
80
- self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
81
- return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
82
-
83
- def stonemask(self, x, fs, tpos, f0):
84
- self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
85
- self.world_dll.StoneMask.restype = None
86
-
87
- out_f0 = (ctypes.c_double * len(f0))()
88
- self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)
89
-
90
- return np.array(out_f0, dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/speaker_diarization/ECAPA_TDNN.py DELETED
@@ -1,280 +0,0 @@
1
- import math
2
- import torch
3
-
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- def length_to_mask(length, max_len=None, dtype=None, device=None):
8
- assert len(length.shape) == 1
9
-
10
- if max_len is None: max_len = length.max().long().item()
11
-
12
- mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
13
-
14
- if dtype is None: dtype = length.dtype
15
- if device is None: device = length.device
16
-
17
- return torch.as_tensor(mask, dtype=dtype, device=device)
18
-
19
- def get_padding_elem(L_in, stride, kernel_size, dilation):
20
- if stride > 1: padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
21
- else:
22
- L_out = (math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1)
23
- padding = [math.floor((L_in - L_out) / 2), math.floor((L_in - L_out) / 2)]
24
-
25
- return padding
26
-
27
- class _BatchNorm1d(nn.Module):
28
- def __init__(self, input_shape=None, input_size=None, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, combine_batch_time=False, skip_transpose=False):
29
- super().__init__()
30
- self.combine_batch_time = combine_batch_time
31
- self.skip_transpose = skip_transpose
32
-
33
- if input_size is None and skip_transpose: input_size = input_shape[1]
34
- elif input_size is None: input_size = input_shape[-1]
35
-
36
- self.norm = nn.BatchNorm1d(input_size, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
37
-
38
- def forward(self, x):
39
- shape_or = x.shape
40
-
41
- if self.combine_batch_time:x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) if x.ndim == 3 else x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
42
- elif not self.skip_transpose: x = x.transpose(-1, 1)
43
-
44
- x_n = self.norm(x)
45
-
46
- if self.combine_batch_time: x_n = x_n.reshape(shape_or)
47
- elif not self.skip_transpose: x_n = x_n.transpose(1, -1)
48
-
49
- return x_n
50
-
51
- class _Conv1d(nn.Module):
52
- def __init__(self, out_channels, kernel_size, input_shape=None, in_channels=None, stride=1, dilation=1, padding="same", groups=1, bias=True, padding_mode="reflect", skip_transpose=False, weight_norm=False, conv_init=None, default_padding=0):
53
- super().__init__()
54
- self.kernel_size = kernel_size
55
- self.stride = stride
56
- self.dilation = dilation
57
- self.padding = padding
58
- self.padding_mode = padding_mode
59
- self.unsqueeze = False
60
- self.skip_transpose = skip_transpose
61
-
62
- if input_shape is None and in_channels is None: raise ValueError
63
- if in_channels is None: in_channels = self._check_input_shape(input_shape)
64
-
65
- self.in_channels = in_channels
66
- self.conv = nn.Conv1d(in_channels, out_channels, self.kernel_size, stride=self.stride, dilation=self.dilation, padding=default_padding, groups=groups, bias=bias)
67
-
68
- if conv_init == "kaiming": nn.init.kaiming_normal_(self.conv.weight)
69
- elif conv_init == "zero": nn.init.zeros_(self.conv.weight)
70
- elif conv_init == "normal": nn.init.normal_(self.conv.weight, std=1e-6)
71
-
72
- if weight_norm: self.conv = nn.utils.weight_norm(self.conv)
73
-
74
- def forward(self, x):
75
- if not self.skip_transpose: x = x.transpose(1, -1)
76
- if self.unsqueeze: x = x.unsqueeze(1)
77
-
78
- if self.padding == "same": x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
79
- elif self.padding == "causal": x = F.pad(x, ((self.kernel_size - 1) * self.dilation, 0))
80
- elif self.padding == "valid": pass
81
- else: raise ValueError
82
-
83
- wx = self.conv(x)
84
-
85
- if self.unsqueeze: wx = wx.squeeze(1)
86
- if not self.skip_transpose: wx = wx.transpose(1, -1)
87
-
88
- return wx
89
-
90
- def _manage_padding(self, x, kernel_size, dilation, stride):
91
- return F.pad(x, get_padding_elem(self.in_channels, stride, kernel_size, dilation), mode=self.padding_mode)
92
-
93
- def _check_input_shape(self, shape):
94
- if len(shape) == 2:
95
- self.unsqueeze = True
96
- in_channels = 1
97
- elif self.skip_transpose: in_channels = shape[1]
98
- elif len(shape) == 3: in_channels = shape[2]
99
- else: raise ValueError
100
-
101
- if not self.padding == "valid" and self.kernel_size % 2 == 0: raise ValueError
102
- return in_channels
103
-
104
- def remove_weight_norm(self):
105
- self.conv = nn.utils.remove_weight_norm(self.conv)
106
-
107
- class Linear(torch.nn.Module):
108
- def __init__(self, n_neurons, input_shape=None, input_size=None, bias=True, max_norm=None, combine_dims=False):
109
- super().__init__()
110
- self.max_norm = max_norm
111
- self.combine_dims = combine_dims
112
-
113
- if input_shape is None and input_size is None: raise ValueError
114
- if input_size is None:
115
- input_size = input_shape[-1]
116
- if len(input_shape) == 4 and self.combine_dims: input_size = input_shape[2] * input_shape[3]
117
-
118
- self.w = nn.Linear(input_size, n_neurons, bias=bias)
119
-
120
- def forward(self, x):
121
- if x.ndim == 4 and self.combine_dims: x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
122
- if self.max_norm is not None: self.w.weight.data = torch.renorm(self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm)
123
-
124
- return self.w(x)
125
-
126
- class Conv1d(_Conv1d):
127
- def __init__(self, *args, **kwargs):
128
- super().__init__(skip_transpose=True, *args, **kwargs)
129
-
130
- class BatchNorm1d(_BatchNorm1d):
131
- def __init__(self, *args, **kwargs):
132
- super().__init__(skip_transpose=True, *args, **kwargs)
133
-
134
- class TDNNBlock(nn.Module):
135
- def __init__(self, in_channels, out_channels, kernel_size, dilation, activation=nn.ReLU, groups=1, dropout=0.0):
136
- super().__init__()
137
- self.conv = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, dilation=dilation, groups=groups)
138
- self.activation = activation()
139
- self.norm = BatchNorm1d(input_size=out_channels)
140
- self.dropout = nn.Dropout1d(p=dropout)
141
-
142
- def forward(self, x):
143
- return self.dropout(self.norm(self.activation(self.conv(x))))
144
-
145
- class Res2NetBlock(torch.nn.Module):
146
- def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1, dropout=0.0):
147
- super().__init__()
148
- assert in_channels % scale == 0
149
- assert out_channels % scale == 0
150
- in_channel = in_channels // scale
151
- hidden_channel = out_channels // scale
152
- self.blocks = nn.ModuleList([TDNNBlock(in_channel, hidden_channel, kernel_size=kernel_size, dilation=dilation, dropout=dropout) for _ in range(scale - 1)])
153
- self.scale = scale
154
-
155
- def forward(self, x):
156
- y = []
157
-
158
- for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
159
- if i == 0: y_i = x_i
160
- elif i == 1: y_i = self.blocks[i - 1](x_i)
161
- else: y_i = self.blocks[i - 1](x_i + y_i)
162
-
163
- y.append(y_i)
164
-
165
- return torch.cat(y, dim=1)
166
-
167
- class SEBlock(nn.Module):
168
- def __init__(self, in_channels, se_channels, out_channels):
169
- super().__init__()
170
-
171
- self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
172
- self.relu = torch.nn.ReLU(inplace=True)
173
- self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
174
- self.sigmoid = torch.nn.Sigmoid()
175
-
176
- def forward(self, x, lengths=None):
177
- L = x.shape[-1]
178
-
179
- if lengths is not None:
180
- mask = length_to_mask(lengths * L, max_len=L, device=x.device).unsqueeze(1)
181
- s = (x * mask).sum(dim=2, keepdim=True) / mask.sum(dim=2, keepdim=True)
182
- else: s = x.mean(dim=2, keepdim=True)
183
-
184
- return self.sigmoid(self.conv2(self.relu(self.conv1(s)))) * x
185
-
186
- class AttentiveStatisticsPooling(nn.Module):
187
- def __init__(self, channels, attention_channels=128, global_context=True):
188
- super().__init__()
189
- self.eps = 1e-12
190
- self.global_context = global_context
191
- self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) if global_context else TDNNBlock(channels, attention_channels, 1, 1)
192
- self.tanh = nn.Tanh()
193
- self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
194
-
195
- def forward(self, x, lengths=None):
196
- L = x.shape[-1]
197
-
198
- def _compute_statistics(x, m, dim=2, eps=self.eps):
199
- mean = (m * x).sum(dim)
200
- return mean, torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
201
-
202
- if lengths is None: lengths = torch.ones(x.shape[0], device=x.device)
203
- mask = length_to_mask(lengths * L, max_len=L, device=x.device).unsqueeze(1)
204
-
205
- if self.global_context:
206
- mean, std = _compute_statistics(x, mask / mask.sum(dim=2, keepdim=True).float())
207
- attn = torch.cat([x, mean.unsqueeze(2).repeat(1, 1, L), std.unsqueeze(2).repeat(1, 1, L)], dim=1)
208
- else: attn = x
209
-
210
- mean, std = _compute_statistics(x, F.softmax(self.conv(self.tanh(self.tdnn(attn))).masked_fill(mask == 0, float("-inf")), dim=2))
211
- return torch.cat((mean, std), dim=1).unsqueeze(2)
212
-
213
- class SERes2NetBlock(nn.Module):
214
- def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1, activation=torch.nn.ReLU, groups=1, dropout=0.0):
215
- super().__init__()
216
- self.out_channels = out_channels
217
- self.tdnn1 = TDNNBlock(in_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, dropout=dropout)
218
- self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
219
- self.tdnn2 = TDNNBlock(out_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, dropout=dropout)
220
- self.se_block = SEBlock(out_channels, se_channels, out_channels)
221
-
222
- self.shortcut = None
223
- if in_channels != out_channels: self.shortcut = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
224
-
225
- def forward(self, x, lengths=None):
226
- residual = x
227
- if self.shortcut: residual = self.shortcut(x)
228
-
229
- return self.se_block(self.tdnn2(self.res2net_block(self.tdnn1(x))), lengths) + residual
230
-
231
- class ECAPA_TDNN(torch.nn.Module):
232
- def __init__(self, input_size, device="cpu", lin_neurons=192, activation=torch.nn.ReLU, channels=[512, 512, 512, 512, 1536], kernel_sizes=[5, 3, 3, 3, 1], dilations=[1, 2, 3, 4, 1], attention_channels=128, res2net_scale=8, se_channels=128, global_context=True, groups=[1, 1, 1, 1, 1], dropout=0.0):
233
- super().__init__()
234
- assert len(channels) == len(kernel_sizes)
235
- assert len(channels) == len(dilations)
236
-
237
- self.channels = channels
238
- self.blocks = nn.ModuleList()
239
-
240
- self.blocks.append(TDNNBlock(input_size, channels[0], kernel_sizes[0], dilations[0], activation, groups[0], dropout))
241
-
242
- for i in range(1, len(channels) - 1):
243
- self.blocks.append(SERes2NetBlock(channels[i - 1], channels[i], res2net_scale=res2net_scale, se_channels=se_channels, kernel_size=kernel_sizes[i], dilation=dilations[i], activation=activation, groups=groups[i], dropout=dropout))
244
-
245
- self.mfa = TDNNBlock(channels[-2] * (len(channels) - 2), channels[-1], kernel_sizes[-1], dilations[-1], activation, groups=groups[-1], dropout=dropout)
246
- self.asp = AttentiveStatisticsPooling(channels[-1], attention_channels=attention_channels, global_context=global_context)
247
- self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
248
- self.fc = Conv1d(in_channels=channels[-1] * 2, out_channels=lin_neurons, kernel_size=1)
249
-
250
- def forward(self, x, lengths=None):
251
- x = x.transpose(1, 2)
252
-
253
- xl = []
254
- for layer in self.blocks:
255
- try:
256
- x = layer(x, lengths=lengths)
257
- except TypeError:
258
- x = layer(x)
259
-
260
- xl.append(x)
261
-
262
- return self.fc(self.asp_bn(self.asp(self.mfa(torch.cat(xl[1:], dim=1)), lengths=lengths))).transpose(1, 2)
263
-
264
- class Classifier(torch.nn.Module):
265
- def __init__(self, input_size, device="cpu", lin_blocks=0, lin_neurons=192, out_neurons=1211):
266
- super().__init__()
267
- self.blocks = nn.ModuleList()
268
-
269
- for _ in range(lin_blocks):
270
- self.blocks.extend([_BatchNorm1d(input_size=input_size), Linear(input_size=input_size, n_neurons=lin_neurons)])
271
- input_size = lin_neurons
272
-
273
- self.weight = nn.Parameter(torch.FloatTensor(out_neurons, input_size, device=device))
274
- nn.init.xavier_uniform_(self.weight)
275
-
276
- def forward(self, x):
277
- for layer in self.blocks:
278
- x = layer(x)
279
-
280
- return F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight)).unsqueeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/speaker_diarization/audio.py DELETED
@@ -1,170 +0,0 @@
1
- import os
2
- import math
3
- import random
4
- import torchaudio
5
-
6
- from io import IOBase
7
- from torch.nn.functional import pad
8
-
9
def get_torchaudio_info(file, backend=None):
    """Return torchaudio metadata for ``file["audio"]``, preferring the soundfile backend."""
    if not backend:
        available = torchaudio.list_audio_backends()
        backend = "soundfile" if "soundfile" in available else available[0]

    info = torchaudio.info(file["audio"], backend=backend)

    # Probing a stream consumes it; rewind so the audio can still be loaded.
    if isinstance(file["audio"], IOBase):
        file["audio"].seek(0)

    return info
18
-
19
class Audio:
    """Audio loading helper: file validation, optional mono downmix and resampling."""

    @staticmethod
    def power_normalize(waveform):
        """Scale `waveform` to unit RMS power (epsilon avoids division by zero)."""
        return waveform / (waveform.square().mean(dim=-1, keepdim=True).sqrt() + 1e-8)

    @staticmethod
    def validate_file(file):
        """Normalize `file` (path, stream or mapping) into a mapping with an "audio" or
        "waveform" entry plus a "uri".

        Raises ValueError for malformed specifications.
        """
        if isinstance(file, (str, os.PathLike)):
            file = {"audio": str(file), "uri": os.path.splitext(os.path.basename(file))[0]}
        elif isinstance(file, IOBase):
            return {"audio": file, "uri": "stream"}
        elif isinstance(file, dict):
            # Fix: mappings ({"waveform": ...} / {"audio": ...}) are valid inputs —
            # the branches below are written for them — but the original raised here.
            pass
        else:
            raise ValueError(f"unsupported audio file specification of type {type(file)}")

        if "waveform" in file:
            waveform = file["waveform"]
            if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
                raise ValueError("'waveform' must be shaped (channel, time) with channel <= time")

            sample_rate: int = file.get("sample_rate", None)
            if sample_rate is None:
                raise ValueError("'waveform' input also requires a 'sample_rate' entry")

            file.setdefault("uri", "waveform")

        elif "audio" in file:
            if isinstance(file["audio"], IOBase):
                return file

            path = os.path.abspath(file["audio"])
            file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])

        else:
            raise ValueError("file must provide either a 'waveform' or an 'audio' entry")

        return file

    def __init__(self, sample_rate: int = None, mono=None, backend: str = None):
        super().__init__()
        self.sample_rate = sample_rate
        self.mono = mono  # None | "downmix" | "random"

        if not backend:
            backends = torchaudio.list_audio_backends()
            backend = "soundfile" if "soundfile" in backends else backends[0]

        self.backend = backend

    def downmix_and_resample(self, waveform, sample_rate):
        """Apply the configured mono reduction and resampling; returns (waveform, rate)."""
        num_channels = waveform.shape[0]

        if num_channels > 1:
            if self.mono == "random":
                channel = random.randint(0, num_channels - 1)
                waveform = waveform[channel : channel + 1]
            elif self.mono == "downmix":
                waveform = waveform.mean(dim=0, keepdim=True)

        if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
            sample_rate = self.sample_rate

        return waveform, sample_rate

    def get_duration(self, file):
        """Duration of `file` in seconds, without decoding when metadata suffices."""
        file = self.validate_file(file)

        if "waveform" in file:
            frames = len(file["waveform"].T)
            sample_rate = file["sample_rate"]
        else:
            info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        return frames / sample_rate

    def get_num_samples(self, duration, sample_rate=None):
        """Number of samples covering `duration` seconds at the effective rate."""
        sample_rate = sample_rate or self.sample_rate
        if sample_rate is None:
            raise ValueError("`sample_rate` must be given here or at construction time")

        return math.floor(duration * sample_rate)

    def __call__(self, file):
        """Load (or take) the waveform and apply downmix/resampling."""
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            sample_rate = file["sample_rate"]
        elif "audio" in file:
            waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
            if isinstance(file["audio"], IOBase):
                file["audio"].seek(0)

        channel = file.get("channel", None)
        if channel is not None:
            waveform = waveform[channel : channel + 1]

        return self.downmix_and_resample(waveform, sample_rate)

    def crop(self, file, segment, duration=None, mode="raise"):
        """Extract `segment` (seconds) from `file`.

        mode="raise" errors on out-of-bounds segments; mode="pad" zero-pads them.
        """
        file = self.validate_file(file)

        if "waveform" in file:
            waveform = file["waveform"]
            frames = waveform.shape[1]
            sample_rate = file["sample_rate"]
        elif "torchaudio.info" in file:
            info = file["torchaudio.info"]
            frames = info.num_frames
            sample_rate = info.sample_rate
        else:
            info = get_torchaudio_info(file, backend=self.backend)
            frames = info.num_frames
            sample_rate = info.sample_rate

        channel = file.get("channel", None)
        start_frame = math.floor(segment.start * sample_rate)

        if duration:
            num_frames = math.floor(duration * sample_rate)
            end_frame = start_frame + num_frames
        else:
            end_frame = math.floor(segment.end * sample_rate)
            num_frames = end_frame - start_frame

        if mode == "raise":
            if num_frames > frames:
                raise ValueError(f"requested {num_frames} frames but file only has {frames}")

            # Allow ~1 ms of slack past the end before complaining.
            if end_frame > frames + math.ceil(0.001 * sample_rate):
                raise ValueError(f"segment ends at frame {end_frame}, past the end of the file ({frames} frames)")
            else:
                end_frame = min(end_frame, frames)
                start_frame = end_frame - num_frames

            if start_frame < 0:
                raise ValueError("segment starts before the beginning of the file")
        elif mode == "pad":
            pad_start = -min(0, start_frame)
            pad_end = max(end_frame, frames) - frames

            start_frame = max(0, start_frame)
            end_frame = min(end_frame, frames)

            num_frames = end_frame - start_frame

        if "waveform" in file:
            data = file["waveform"][:, start_frame:end_frame]
        else:
            try:
                data, _ = torchaudio.load(file["audio"], frame_offset=start_frame, num_frames=num_frames, backend=self.backend)
                if isinstance(file["audio"], IOBase):
                    file["audio"].seek(0)
            except RuntimeError:
                # Streams cannot be decoded twice: re-raise the original error
                # (the old code raised a bare RuntimeError and lost the message).
                if isinstance(file["audio"], IOBase):
                    raise

                # Fall back to decoding the whole file once and caching it.
                waveform, sample_rate = self.__call__(file)
                data = waveform[:, start_frame:end_frame]

                file["waveform"] = waveform
                file["sample_rate"] = sample_rate

        if channel is not None:
            data = data[channel : channel + 1, :]
        if mode == "pad":
            data = pad(data, (pad_start, pad_end))

        return self.downmix_and_resample(data, sample_rate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/speaker_diarization/embedding.py DELETED
@@ -1,90 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- import numpy as np
6
- import torch.nn.functional as F
7
-
8
- from functools import cached_property
9
- from torch.nn.utils.rnn import pad_sequence
10
-
11
- sys.path.append(os.getcwd())
12
-
13
- from main.library.speaker_diarization.speechbrain import EncoderClassifier
14
-
15
class BaseInference:
    """Marker base class for inference wrappers; carries no behavior of its own."""
17
-
18
class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
    """Speaker-embedding extractor wrapping a SpeechBrain EncoderClassifier."""

    def __init__(self, embedding="assets/models/speaker_diarization/models/speechbrain", device=None):
        super().__init__()

        self.embedding = embedding
        self.device = device or torch.device("cpu")
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})

    def to(self, device):
        """Move the extractor to `device` by rebuilding the underlying classifier."""
        if not isinstance(device, torch.device):
            raise TypeError(f"expected torch.device, got {type(device)}")

        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
        self.device = device
        return self

    @cached_property
    def sample_rate(self):
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self):
        # Probe with one second of noise to discover the embedding size.
        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
        return dimension

    @cached_property
    def metric(self):
        return "cosine"

    @cached_property
    def min_num_samples(self):
        """Shortest input (in samples) the encoder accepts, found by bisection."""
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2

            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
                    upper = middle
                except RuntimeError:
                    lower = middle

                middle = (lower + upper) // 2

            return upper

    def __call__(self, waveforms, masks=None):
        """Embed a batch of mono waveforms (batch, 1, samples).

        Rows whose (masked) signal is shorter than `min_num_samples` come back
        as NaN instead of raising.
        """
        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            # Fix: `waveforms` is already (batch, samples); the former second
            # squeeze(dim=1) could wrongly drop the sample axis when num_samples == 1.
            signals = waveforms
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            # Resample masks to sample resolution and keep only masked samples.
            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        too_short = wav_lens < self.min_num_samples
        wav_lens = wav_lens / max_len  # SpeechBrain expects relative lengths in [0, 1].
        wav_lens[too_short] = 1.0

        embeddings = (self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy())
        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/speaker_diarization/encoder.py DELETED
@@ -1,250 +0,0 @@
1
- import os
2
- import sys
3
- import ast
4
- import torch
5
- import itertools
6
- import collections
7
-
8
- sys.path.append(os.getcwd())
9
-
10
- from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier
11
- from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader
12
-
13
DEFAULT_UNK = "<unk>"
DEFAULT_BOS = "<bos>"
DEFAULT_EOS = "<eos>"
DEFAULT_BLANK = "<blank>"

@register_checkpoint_hooks
class CategoricalEncoder:
    """Bidirectional label <-> integer-index encoder with text-file persistence."""

    VALUE_SEPARATOR = " => "
    EXTRAS_SEPARATOR = "================\n"

    def __init__(self, starting_index=0, **special_labels):
        self.lab2ind = {}   # label -> int index
        self.ind2lab = {}   # int index -> label (inverse of lab2ind)
        self.starting_index = starting_index
        self.handle_special_labels(special_labels)

    def handle_special_labels(self, special_labels):
        if "unk_label" in special_labels:
            self.add_unk(special_labels["unk_label"])

    def __len__(self):
        return len(self.lab2ind)

    @classmethod
    def from_saved(cls, path):
        """Alternate constructor: load an encoder previously written with save()."""
        obj = cls()
        obj.load(path)
        return obj

    def update_from_iterable(self, iterable, sequence_input=False):
        """Add every (possibly nested, if `sequence_input`) label from `iterable`."""
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        for label in label_iterator:
            self.ensure_label(label)

    def update_from_didataset(self, didataset, output_key, sequence_input=False):
        with didataset.output_keys_as([output_key]):
            self.update_from_iterable((data_point[output_key] for data_point in didataset), sequence_input=sequence_input)

    def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
        """Add only the most frequent labels; returns the full Counter for inspection."""
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        counts = collections.Counter(label_iterator)

        for label, count in counts.most_common(n_most_common):
            if count < min_count:
                break
            self.add_label(label)

        return counts

    def load_or_create(self, path, from_iterables=None, from_didatasets=None, sequence_input=False, output_key=None, special_labels=None):
        """Load from `path` if possible, otherwise build (main process only) and save.

        All ranks load the final file after the DDP barrier.
        """
        # Fix: avoid mutable default arguments.
        from_iterables = [] if from_iterables is None else from_iterables
        from_didatasets = [] if from_didatasets is None else from_didatasets
        special_labels = {} if special_labels is None else special_labels

        try:
            if if_main_process():
                if not self.load_if_possible(path):
                    for iterable in from_iterables:
                        self.update_from_iterable(iterable, sequence_input)

                    for didataset in from_didatasets:
                        if output_key is None:
                            raise ValueError("`output_key` is required together with `from_didatasets`")
                        self.update_from_didataset(didataset, output_key, sequence_input)

                    self.handle_special_labels(special_labels)
                    self.save(path)
        finally:
            ddp_barrier()
            self.load(path)

    def add_label(self, label):
        """Register a new label; raises KeyError if it already exists."""
        if label in self.lab2ind:
            raise KeyError(f"Label {label!r} already present in the encoder")
        index = self._next_index()

        self.lab2ind[label] = index
        self.ind2lab[index] = label

        return index

    def ensure_label(self, label):
        """Return the label's index, adding the label if it is new."""
        if label in self.lab2ind:
            return self.lab2ind[label]
        else:
            return self.add_label(label)

    def insert_label(self, label, index):
        if label in self.lab2ind:
            raise KeyError(f"Label {label!r} already present in the encoder")
        else:
            self.enforce_label(label, index)

    def enforce_label(self, label, index):
        """Force `label` onto `index`, relocating any label that occupied it."""
        index = int(index)

        if label in self.lab2ind:
            if index == self.lab2ind[label]:
                return
            else:
                del self.ind2lab[self.lab2ind[label]]

        if index in self.ind2lab:
            saved_label = self.ind2lab[index]
            moving_other = True
        else:
            moving_other = False

        self.lab2ind[label] = index
        self.ind2lab[index] = label

        if moving_other:
            new_index = self._next_index()
            self.lab2ind[saved_label] = new_index
            self.ind2lab[new_index] = saved_label

    def add_unk(self, unk_label=DEFAULT_UNK):
        self.unk_label = unk_label
        return self.add_label(unk_label)

    def _next_index(self):
        # First free integer index at or above starting_index.
        index = self.starting_index
        while index in self.ind2lab:
            index += 1

        return index

    def is_continuous(self):
        """True if indices form a gapless run starting at `starting_index`."""
        indices = sorted(self.ind2lab.keys())
        return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices[:-1], indices[1:]))

    def encode_label(self, label, allow_unk=True):
        self._assert_len()

        try:
            return self.lab2ind[label]
        except KeyError:
            # All failure branches of the original raised bare KeyError; collapse
            # them into one informative raise with identical exception type.
            if allow_unk and hasattr(self, "unk_label"):
                return self.lab2ind[self.unk_label]
            raise KeyError(f"Label {label!r} is unknown and no usable unk label is configured")

    def encode_label_torch(self, label, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk)])

    def encode_sequence(self, sequence, allow_unk=True):
        self._assert_len()
        return [self.encode_label(label, allow_unk) for label in sequence]

    def encode_sequence_torch(self, sequence, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])

    def decode_torch(self, x):
        """Decode a 1-D or nested tensor of indices back to labels (nested list)."""
        self._assert_len()
        decoded = []

        if x.ndim == 1:
            for element in x:
                decoded.append(self.ind2lab[int(element)])
        else:
            for subtensor in x:
                decoded.append(self.decode_torch(subtensor))

        return decoded

    def decode_ndim(self, x):
        """Decode arbitrarily nested iterables of indices back to labels."""
        self._assert_len()
        try:
            decoded = []
            for subtensor in x:
                decoded.append(self.decode_ndim(subtensor))

            return decoded
        except TypeError:
            # Leaf value: not iterable.
            return self.ind2lab[int(x)]

    @mark_as_saver
    def save(self, path):
        self._save_literal(path, self.lab2ind, self._get_extras())

    def load(self, path):
        lab2ind, ind2lab, extras = self._load_literal(path)
        self.lab2ind = lab2ind
        self.ind2lab = ind2lab
        self._set_extras(extras)

    @mark_as_loader
    def load_if_possible(self, path, end_of_epoch=False):
        """Attempt a load; return True on success, False on missing/corrupt file."""
        del end_of_epoch  # unused; kept for the checkpoint-hook signature

        try:
            self.load(path)
        except FileNotFoundError:
            return False
        except (ValueError, SyntaxError):
            return False

        return True

    def expect_len(self, expected_len):
        self.expected_len = expected_len

    def ignore_len(self):
        self.expected_len = None

    def _assert_len(self):
        if hasattr(self, "expected_len"):
            if self.expected_len is None:
                return
            if len(self) != self.expected_len:
                raise RuntimeError(f"Encoder has {len(self)} labels but {self.expected_len} were expected")
        else:
            self.ignore_len()
            return

    def _get_extras(self):
        extras = {"starting_index": self.starting_index}
        if hasattr(self, "unk_label"):
            extras["unk_label"] = self.unk_label

        return extras

    def _set_extras(self, extras):
        if "unk_label" in extras:
            self.unk_label = extras["unk_label"]
        self.starting_index = extras["starting_index"]

    @staticmethod
    def _save_literal(path, lab2ind, extras):
        """Write `repr(label) => index` lines, a separator, then extras."""
        with open(path, "w", encoding="utf-8") as f:
            for label, ind in lab2ind.items():
                f.write(repr(label) + CategoricalEncoder.VALUE_SEPARATOR + str(ind) + "\n")

            f.write(CategoricalEncoder.EXTRAS_SEPARATOR)

            for key, value in extras.items():
                f.write(repr(key) + CategoricalEncoder.VALUE_SEPARATOR + repr(value) + "\n")

            f.flush()

    @staticmethod
    def _load_literal(path):
        lab2ind, ind2lab, extras = {}, {}, {}

        with open(path, encoding="utf-8") as f:
            for line in f:
                if line == CategoricalEncoder.EXTRAS_SEPARATOR:
                    break
                literal, ind = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                label = ast.literal_eval(literal)
                index = int(ind)
                lab2ind[label] = index
                # Fix: ind2lab keys must be ints — decode_torch/decode_ndim and
                # _next_index look indices up as ints, but the old code stored
                # the raw string, breaking decoding after a load.
                ind2lab[index] = label

            for line in f:
                literal_key, literal_value = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)

        return lab2ind, ind2lab, extras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/speaker_diarization/features.py DELETED
@@ -1,520 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
- import inspect
6
- import functools
7
-
8
- sys.path.append(os.getcwd())
9
-
10
- from main.library.speaker_diarization.speechbrain import MAIN_PROC_ONLY, is_distributed_initialized, main_process_only
11
-
12
# Renamed state-dict keys: old (typo) spelling -> corrected spelling.
KEYS_MAPPING = {".mutihead_attn": ".multihead_attn", ".convs_intermedite": ".convs_intermediate"}

def map_old_state_dict_weights(state_dict, mapping):
    """Rename every key containing an old substring to its replacement, in place."""
    for old_fragment, new_fragment in mapping.items():
        for key in list(state_dict.keys()):
            if old_fragment in key:
                state_dict[key.replace(old_fragment, new_fragment)] = state_dict.pop(key)

    return state_dict

def hook_on_loading_state_dict_checkpoint(state_dict):
    """Apply the default legacy-key renames to a freshly loaded state dict."""
    return map_old_state_dict_weights(state_dict, KEYS_MAPPING)
23
-
24
def torch_patched_state_dict_load(path, device="cpu"):
    """Load a checkpoint from `path` onto `device` and apply legacy key renames."""
    raw_state_dict = torch.load(path, map_location=device)
    return hook_on_loading_state_dict_checkpoint(raw_state_dict)
26
-
27
@main_process_only
def torch_save(obj, path):
    """Persist `obj.state_dict()` to `path` (runs on the main process only)."""
    torch.save(obj.state_dict(), path)
31
-
32
def torch_recovery(obj, path, end_of_epoch):
    """Reload `obj`'s parameters from the checkpoint at `path`."""
    del end_of_epoch  # unused; kept for the checkpoint-hook signature

    state_dict = torch_patched_state_dict_load(path, "cpu")
    try:
        obj.load_state_dict(state_dict, strict=True)
    except TypeError:
        # Some objects (e.g. optimizers) do not accept a `strict` argument.
        obj.load_state_dict(state_dict)
40
-
41
def torch_parameter_transfer(obj, path):
    """Non-strict parameter transfer: load whatever keys match, silently skip the rest.

    The original iterated `missing_keys`/`unexpected_keys` with empty loop bodies
    (a leftover logging placeholder); those no-op loops are removed.
    """
    incompatible_keys = obj.load_state_dict(torch_patched_state_dict_load(path, "cpu"), strict=False)
    del incompatible_keys  # intentionally ignored: best-effort transfer
48
-
49
# Sentinel stored in place of CyclicLR's unpicklable `_scale_fn_ref` weakref.
WEAKREF_MARKER = "WEAKREF"

def _cycliclrsaver(obj, path):
    """Save a CyclicLR scheduler, replacing its weakref member with a marker."""
    state_dict = obj.state_dict()
    if state_dict.get("_scale_fn_ref") is not None:
        state_dict["_scale_fn_ref"] = WEAKREF_MARKER

    torch.save(state_dict, path)
56
-
57
def _cycliclrloader(obj, path, end_of_epoch):
    """Restore a CyclicLR scheduler from the checkpoint at `path`."""
    del end_of_epoch  # unused; kept for the checkpoint-hook signature

    state_dict = torch.load(path, map_location="cpu")
    try:
        obj.load_state_dict(state_dict, strict=True)
    except TypeError:
        # Older load_state_dict signatures have no `strict` argument.
        obj.load_state_dict(state_dict)
64
-
65
# Per-type default checkpoint hooks. register_checkpoint_hooks() adds entries
# for decorated classes; CyclicLR needs special handling for its weakref member.
DEFAULT_LOAD_HOOKS = {
    torch.nn.Module: torch_recovery,
    torch.optim.Optimizer: torch_recovery,
    torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery,
    torch.cuda.amp.grad_scaler.GradScaler: torch_recovery,
}
DEFAULT_SAVE_HOOKS = {
    torch.nn.Module: torch_save,
    torch.optim.Optimizer: torch_save,
    torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save,
    torch.cuda.amp.grad_scaler.GradScaler: torch_save,
}
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_recovery
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_save
DEFAULT_TRANSFER_HOOKS = {torch.nn.Module: torch_parameter_transfer}
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrsaver
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrloader
72
-
73
def register_checkpoint_hooks(cls, save_on_main_only=True):
    """Class decorator: register methods marked as saver/loader/transfer hooks."""
    global DEFAULT_LOAD_HOOKS, DEFAULT_SAVE_HOOKS, DEFAULT_TRANSFER_HOOKS

    for member in cls.__dict__.values():
        if hasattr(member, "_speechbrain_saver"):
            DEFAULT_SAVE_HOOKS[cls] = main_process_only(member) if save_on_main_only else member
        if hasattr(member, "_speechbrain_loader"):
            DEFAULT_LOAD_HOOKS[cls] = member
        if hasattr(member, "_speechbrain_transfer"):
            DEFAULT_TRANSFER_HOOKS[cls] = member

    return cls
82
-
83
def mark_as_saver(method):
    """Mark `method` as a checkpoint saver hook.

    The method must be callable as ``method(self, path)``; the original raised a
    bare TypeError on a bad signature, which discarded the binding error.
    """
    sig = inspect.signature(method)

    try:
        sig.bind(object(), os.path.join("testpath"))
    except TypeError as e:
        raise TypeError("Checkpoint saver must be callable as method(self, path)") from e

    method._speechbrain_saver = True
    return method
93
-
94
def mark_as_transfer(method):
    """Mark `method` as a parameter-transfer hook.

    The method must be callable as ``method(self, path)``; the original raised a
    bare TypeError on a bad signature, which discarded the binding error.
    """
    sig = inspect.signature(method)

    try:
        sig.bind(object(), os.path.join("testpath"))
    except TypeError as e:
        raise TypeError("Transfer hook must be callable as method(self, path)") from e

    method._speechbrain_transfer = True
    return method
104
-
105
def mark_as_loader(method):
    """Mark `method` as a checkpoint loader hook.

    The method must be callable as ``method(self, path, end_of_epoch)``; the
    original raised a bare TypeError on a bad signature, which discarded the
    binding error.
    """
    sig = inspect.signature(method)

    try:
        sig.bind(object(), os.path.join("testpath"), True)
    except TypeError as e:
        raise TypeError("Checkpoint loader must be callable as method(self, path, end_of_epoch)") from e

    method._speechbrain_loader = True
    return method
115
-
116
def ddp_all_reduce(communication_object, reduce_op):
    """All-reduce `communication_object` across DDP ranks when distributed is active."""
    skip_reduce = MAIN_PROC_ONLY >= 1 or not is_distributed_initialized()
    if not skip_reduce:
        torch.distributed.all_reduce(communication_object, op=reduce_op)

    return communication_object
121
-
122
def fwd_default_precision(fwd=None, cast_inputs=torch.float32):
    """Decorate `fwd` so that, under autocast, its inputs are cast to `cast_inputs`.

    Usable bare (``@fwd_default_precision``) or parameterized
    (``@fwd_default_precision(cast_inputs=...)``). The wrapped function gains a
    ``force_allow_autocast`` keyword that bypasses the cast entirely.
    """
    if fwd is None:
        # Parameterized usage: return a decorator waiting for the function.
        return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)

    casting_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def wrapper(*args, force_allow_autocast=False, **kwargs):
        if force_allow_autocast:
            return fwd(*args, **kwargs)
        return casting_fwd(*args, **kwargs)

    return wrapper
132
-
133
def spectral_magnitude(stft, power=1, log=False, eps=1e-14):
    """Spectrum from a real-view STFT whose last dimension holds (real, imag).

    power=1 yields the power spectrum (sum of squares); power=0.5 the magnitude
    spectrum; `log` applies a stabilized natural log on top.
    """
    spectrum = (stft ** 2).sum(-1)

    # Stabilize fractional powers around zero.
    if power < 1:
        spectrum = spectrum + eps

    spectrum = spectrum.pow(power)
    return torch.log(spectrum + eps) if log else spectrum
141
-
142
class Filterbank(torch.nn.Module):
    """Mel-style filterbank applied to spectrograms.

    Builds `n_mels` triangular/rectangular/gaussian filters between `f_min` and
    `f_max`; when `freeze` is False the center frequencies and bandwidths become
    learnable parameters. Optionally converts the output to dB (`log_mel`).
    """

    def __init__(self, n_mels=40, log_mel=True, filter_shape="triangular", f_min=0, f_max=8000, n_fft=400, sample_rate=16000, power_spectrogram=2, amin=1e-10, ref_value=1.0, top_db=80.0, param_change_factor=1.0, param_rand_factor=0.0, freeze=True):
        super().__init__()
        self.n_mels = n_mels
        self.log_mel = log_mel
        self.filter_shape = filter_shape
        self.f_min = f_min
        self.f_max = f_max
        self.n_fft = n_fft
        self.sample_rate = sample_rate
        self.power_spectrogram = power_spectrogram
        self.amin = amin
        self.ref_value = ref_value
        self.top_db = top_db
        self.freeze = freeze
        # Number of positive-frequency STFT bins.
        self.n_stft = self.n_fft // 2 + 1
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))
        self.device_inp = torch.device("cpu")
        self.param_change_factor = param_change_factor
        self.param_rand_factor = param_rand_factor
        # 10*log10 for power spectra (power_spectrogram == 2), else 20*log10.
        self.multiplier = 10 if self.power_spectrogram == 2 else 20

        # Filter edges equally spaced on the mel scale, converted back to Hz.
        hz = self._to_hz(torch.linspace(self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2))

        band = hz[1:] - hz[:-1]
        self.band = band[:-1]
        self.f_central = hz[1:-1]

        if not self.freeze:
            # Learnable variant: stored rescaled by sample_rate * param_change_factor
            # (undone in forward).
            self.f_central = torch.nn.Parameter(self.f_central / (self.sample_rate * self.param_change_factor))
            self.band = torch.nn.Parameter(self.band / (self.sample_rate * self.param_change_factor))

        # One row of linear frequencies per mel filter.
        self.all_freqs_mat = torch.linspace(0, self.sample_rate // 2, self.n_stft).repeat(self.f_central.shape[0], 1)

    def forward(self, spectrogram):
        """Apply the filterbank; assumes input shaped (batch, time, freq[, channel]) — TODO confirm."""
        f_central_mat = self.f_central.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
        band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)

        if not self.freeze:
            # Undo the storage rescaling of the learnable parameters.
            f_central_mat = f_central_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
            band_mat = band_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
        elif self.param_rand_factor != 0 and self.training:
            # Random filterbank perturbation (augmentation) during training only.
            rand_change = (1.0 + torch.rand(2) * 2 * self.param_rand_factor - self.param_rand_factor)
            f_central_mat = f_central_mat * rand_change[0]
            band_mat = band_mat * rand_change[1]

        fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(spectrogram.device)
        sp_shape = spectrogram.shape
        # 4-D input: fold the trailing channel axis into the batch dimension.
        if len(sp_shape) == 4: spectrogram = spectrogram.permute(0, 3, 1, 2).reshape(sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2])

        fbanks = torch.matmul(spectrogram, fbank_matrix)
        if self.log_mel: fbanks = self._amplitude_to_DB(fbanks)

        if len(sp_shape) == 4:
            # Restore the (batch, time, mel, channel) layout.
            fb_shape = fbanks.shape
            fbanks = fbanks.reshape(sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]).permute(0, 2, 3, 1)

        return fbanks

    @staticmethod
    def _to_mel(hz):
        # Hz -> mel (HTK formula).
        return 2595 * math.log10(1 + hz / 700)

    @staticmethod
    def _to_hz(mel):
        # mel -> Hz (inverse HTK formula).
        return 700 * (10 ** (mel / 2595) - 1)

    def _triangular_filters(self, all_freqs, f_central, band):
        # Rising and falling slopes clipped at zero -> triangle per filter.
        slope = (all_freqs - f_central) / band
        return torch.max(torch.zeros(1, device=self.device_inp), torch.min(slope + 1.0, -slope + 1.0)).transpose(0, 1)

    def _rectangular_filters(self, all_freqs, f_central, band):
        # NOTE(review): the first assignment to `right_size` is immediately
        # overwritten below; `left_side` keeps the lower-edge comparison.
        left_side = right_size = all_freqs.ge(f_central - band)
        right_size = all_freqs.le(f_central + band)

        return (left_side * right_size).float().transpose(0, 1)

    def _gaussian_filters(self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)):
        # Gaussian bump centered on f_central with width band / smooth_factor.
        return torch.exp(-0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2).transpose(0, 1)

    def _create_fbank_matrix(self, f_central_mat, band_mat):
        # Any filter_shape other than "triangular"/"rectangular" falls back to gaussian.
        if self.filter_shape == "triangular": fbank_matrix = self._triangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
        elif self.filter_shape == "rectangular": fbank_matrix = self._rectangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
        else: fbank_matrix = self._gaussian_filters(self.all_freqs_mat, f_central_mat, band_mat)

        return fbank_matrix

    def _amplitude_to_DB(self, x):
        # Convert to dB relative to ref_value, clamped to top_db below the per-example max.
        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
        x_db -= self.multiplier * self.db_multiplier

        return torch.max(x_db, (x_db.amax(dim=(-2, -1)) - self.top_db).view(x_db.shape[0], 1, 1))
234
-
235
class ContextWindow(torch.nn.Module):
    """Gathers left/right context frames for each time step via a grouped conv1d.

    The identity-eye kernel is expanded lazily on the first forward call, once
    the number of feature channels is known.
    """

    def __init__(self, left_frames=0, right_frames=0):
        super().__init__()
        self.left_frames = left_frames
        self.right_frames = right_frames
        self.context_len = self.left_frames + self.right_frames + 1
        self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1
        # Identity selector: one row per context offset.
        self.kernel = torch.eye(self.context_len, self.kernel_len)

        # Asymmetric context: shift the selectors so the window is centered correctly.
        if self.right_frames > self.left_frames: self.kernel = torch.roll(self.kernel, self.right_frames - self.left_frames, 1)
        self.first_call = True

    def forward(self, x):
        # (batch, time, feat) -> (batch, feat, time) for conv1d.
        x = x.transpose(1, 2)
        if self.first_call:
            # Lazily expand the kernel to one selector per input channel
            # (needs x's channel count, unknown at construction time).
            self.first_call = False
            self.kernel = (self.kernel.repeat(x.shape[1], 1, 1).view(x.shape[1] * self.context_len, self.kernel_len).unsqueeze(1))

        or_shape = x.shape
        # 4-D input: fold the extra axis into the batch dimension.
        if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])

        cw_x = torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1], padding=max(self.left_frames, self.right_frames))
        if len(or_shape) == 4: cw_x = cw_x.reshape(or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1])

        # Back to (batch, time, feat * context_len).
        return cw_x.transpose(1, 2)
260
-
261
class FilterProperties:
    """Geometry (receptive field, stride, dilation, causality) of a sliding-window filter."""

    def __init__(self, window_size=0, stride=1, dilation=1, causal=False):
        self.window_size = window_size
        self.stride = stride
        self.dilation = dilation
        self.causal = causal

    def __post_init__(self):
        # NOTE(review): only dataclasses auto-invoke __post_init__; as a plain
        # class these sanity checks never run unless called explicitly.
        assert self.window_size > 0
        assert self.stride > 0
        assert self.dilation > 0

    @staticmethod
    def pointwise_filter():
        """A 1x1 filter with unit stride."""
        return FilterProperties(window_size=1, stride=1)

    def get_effective_size(self):
        # Dilation spreads the taps apart without adding new ones.
        return (self.window_size - 1) * self.dilation + 1

    def get_convolution_padding(self):
        """Padding that keeps the output length aligned; requires an odd window."""
        if self.window_size % 2 == 0:
            raise ValueError

        span = self.get_effective_size() - 1
        return span if self.causal else span // 2

    def get_noncausal_equivalent(self):
        """Centered filter covering at least the same receptive field."""
        if not self.causal:
            return self

        doubled = (self.window_size - 1) * 2 + 1
        return FilterProperties(window_size=doubled, stride=self.stride, dilation=self.dilation, causal=False)

    def with_on_top(self, other, allow_approximate=True):
        """Properties of `other` stacked on the output of `self`."""
        if other.window_size % 2 == 0:
            if not allow_approximate:
                raise ValueError
            other_size = other.window_size + 1  # approximate with the next odd size
        else:
            other_size = other.window_size

        # Mixed causality: fall back to the non-causal equivalents.
        if bool(self.causal) != bool(other.causal):
            if not allow_approximate:
                raise ValueError
            return self.get_noncausal_equivalent().with_on_top(other.get_noncausal_equivalent())

        return FilterProperties(
            self.window_size + (self.stride * (other_size - 1)),
            self.stride * other.stride,
            self.dilation * other.dilation,
            self.causal,
        )
303
-
304
class STFT(torch.nn.Module):
    """Short-time Fourier transform returning a real-view (.., 2) tensor.

    `win_length` and `hop_length` are given in milliseconds and converted to
    samples from `sample_rate`.
    """

    def __init__(self, sample_rate, win_length=25, hop_length=10, n_fft=400, window_fn=torch.hamming_window, normalized_stft=False, center=True, pad_mode="constant", onesided=True):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.normalized_stft = normalized_stft
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided

        # Milliseconds -> samples.
        self.win_length = int(round((self.sample_rate / 1000.0) * win_length))
        self.hop_length = int(round((self.sample_rate / 1000.0) * hop_length))
        self.window = window_fn(self.win_length)

    def forward(self, x):
        """STFT of (batch, time) or multi-channel (batch, time, channel) input."""
        input_shape = x.shape
        multi_channel = len(input_shape) == 3

        # Fold channels into the batch dimension: (B, T, C) -> (B*C, T).
        if multi_channel:
            x = x.transpose(1, 2).reshape(input_shape[0] * input_shape[2], input_shape[1])

        complex_stft = torch.stft(x, self.n_fft, self.hop_length, self.win_length, self.window.to(x.device), self.center, self.pad_mode, self.normalized_stft, self.onesided, return_complex=True)
        stft = torch.view_as_real(complex_stft)

        if multi_channel:
            stft = stft.reshape(input_shape[0], input_shape[2], stft.shape[1], stft.shape[2], stft.shape[3]).permute(0, 3, 2, 4, 1)
        else:
            # (B, freq, frames, 2) -> (B, frames, freq, 2)
            stft = stft.transpose(2, 1)

        return stft

    def get_filter_properties(self):
        if not self.center:
            raise ValueError
        return FilterProperties(window_size=self.win_length, stride=self.hop_length)
331
-
332
class Deltas(torch.nn.Module):
    """Computes delta (time-derivative) coefficients of feature streams.

    Uses the standard regression formula over a window of `window_length`
    frames, implemented as a grouped 1-D convolution with a ramp kernel.
    """

    def __init__(self, input_size, window_length=5):
        super().__init__()
        half = (window_length - 1) // 2
        self.n = half
        # Normalizer of the delta regression: sum of squared offsets * ... / 3.
        self.denom = half * (half + 1) * (2 * half + 1) / 3
        ramp = torch.arange(-half, half + 1, dtype=torch.float32)
        # One identical ramp filter per feature channel (grouped conv).
        self.register_buffer("kernel", ramp.repeat(input_size, 1, 1))

    def forward(self, x):
        """Return delta coefficients with the same layout as the input."""
        # Move features to the channel axis for conv1d.
        x = x.transpose(1, 2).transpose(2, -1)
        shape = x.shape
        folded = len(shape) == 4

        if folded:
            # Fold the extra axis into the batch dimension.
            x = x.reshape(shape[0] * shape[2], shape[1], shape[3])

        padded = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
        deltas = torch.nn.functional.conv1d(padded, self.kernel.to(x.device), groups=padded.shape[1]) / self.denom

        if folded:
            deltas = deltas.reshape(shape[0], shape[1], shape[2], shape[3])

        # Restore the original axis order.
        return deltas.transpose(1, -1).transpose(2, -1)
350
-
351
class Fbank(torch.nn.Module):
    """Filterbank feature extractor: STFT -> magnitude -> mel filterbank,
    optionally augmented with delta features and a context window."""

    def __init__(self, deltas=False, context=False, requires_grad=False, sample_rate=16000, f_min=0, f_max=None, n_fft=400, n_mels=40, filter_shape="triangular", param_change_factor=1.0, param_rand_factor=0.0, left_frames=5, right_frames=5, win_length=25, hop_length=10):
        super().__init__()
        self.deltas = deltas
        self.context = context
        self.requires_grad = requires_grad

        # Default the upper filter edge to the Nyquist frequency.
        if f_max is None:
            f_max = sample_rate / 2

        self.compute_STFT = STFT(sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
        self.compute_fbanks = Filterbank(sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, f_min=f_min, f_max=f_max, freeze=not requires_grad, filter_shape=filter_shape, param_change_factor=param_change_factor, param_rand_factor=param_rand_factor)
        self.compute_deltas = Deltas(input_size=n_mels)
        self.context_window = ContextWindow(left_frames=left_frames, right_frames=right_frames)

    @fwd_default_precision(cast_inputs=torch.float32)
    def forward(self, wav):
        """Return filterbank features for a batch of waveforms."""
        features = self.compute_STFT(wav)
        features = self.compute_fbanks(spectral_magnitude(features))

        if self.deltas:
            # Append first- and second-order deltas along the feature axis.
            first_order = self.compute_deltas(features)
            second_order = self.compute_deltas(first_order)
            features = torch.cat([features, first_order, second_order], dim=2)

        if self.context:
            features = self.context_window(features)

        return features

    def get_filter_properties(self):
        """Expose the underlying STFT's filter properties."""
        return self.compute_STFT.get_filter_properties()
375
-
376
@register_checkpoint_hooks
class InputNormalization(torch.nn.Module):
    """Mean/variance normalization of input features.

    Depending on ``norm_type`` the statistics are computed per sentence,
    per speaker (tracked in dicts keyed by speaker id), per batch, or
    globally (a running average updated until ``update_until_epoch``).
    Statistics are checkpointed via the marked _save/_load hooks below.
    """

    def __init__(self, mean_norm=True, std_norm=True, norm_type="global", avg_factor=None, requires_grad=False, update_until_epoch=3):
        super().__init__()
        self.mean_norm = mean_norm  # when True, subtract the mean
        self.std_norm = std_norm  # when True, divide by the std
        self.norm_type = norm_type  # "sentence" | "speaker" | "batch" | "global"
        self.avg_factor = avg_factor  # fixed running-average weight; None -> 1/count
        self.requires_grad = requires_grad
        self.glob_mean = torch.tensor([0])  # running global mean
        self.glob_std = torch.tensor([0])  # running global std
        self.spk_dict_mean = {}  # per-speaker running means
        self.spk_dict_std = {}  # per-speaker running stds
        self.spk_dict_count = {}  # per-speaker update counts
        self.weight = 1.0  # most recently used running-average weight
        self.count = 0  # number of global updates performed so far
        self.eps = 1e-10  # floor for std to avoid division by zero
        self.update_until_epoch = update_until_epoch

    def forward(self, x, lengths, spk_ids = torch.tensor([]), epoch=0):
        """Normalize ``x`` according to ``self.norm_type``.

        ``lengths`` holds relative lengths in [0, 1] used to exclude padding
        from the statistics; ``spk_ids`` is only read in "speaker" mode.
        NOTE(review): the tensor default for spk_ids is shared across calls;
        harmless here since it is only indexed, never mutated.
        """
        N_batches = x.shape[0]
        current_means, current_stds = [], []

        if self.norm_type == "sentence" or self.norm_type == "speaker": out = torch.empty_like(x)

        for snt_id in range(N_batches):
            # Only the non-padded prefix of each sentence contributes.
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
            current_mean, current_std = self._compute_current_stats(x[snt_id, 0:actual_size, ...])

            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence": out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

            if self.norm_type == "speaker":
                spk_id = int(spk_ids[snt_id][0])

                if self.training:
                    if spk_id not in self.spk_dict_mean:
                        # First time this speaker is seen: initialize its stats.
                        self.spk_dict_mean[spk_id] = current_mean
                        self.spk_dict_std[spk_id] = current_std
                        self.spk_dict_count[spk_id] = 1
                    else:
                        self.spk_dict_count[spk_id] = (self.spk_dict_count[spk_id] + 1)
                        self.weight = (1 / self.spk_dict_count[spk_id]) if self.avg_factor is None else self.avg_factor

                        # Running average of the speaker statistics.
                        self.spk_dict_mean[spk_id] = (1 - self.weight) * self.spk_dict_mean[spk_id].to(current_mean) + self.weight * current_mean
                        self.spk_dict_std[spk_id] = (1 - self.weight) * self.spk_dict_std[spk_id].to(current_std) + self.weight * current_std

                    # NOTE(review): Tensor.detach() is not in-place; these calls
                    # return new tensors that are discarded.
                    self.spk_dict_mean[spk_id].detach()
                    self.spk_dict_std[spk_id].detach()

                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                else:
                    if spk_id in self.spk_dict_mean:
                        speaker_mean = self.spk_dict_mean[spk_id].data
                        speaker_std = self.spk_dict_std[spk_id].data
                    else:
                        # Unknown speaker at eval time: fall back to sentence stats.
                        speaker_mean = current_mean.data
                        speaker_std = current_std.data

                out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type == "batch" or self.norm_type == "global":
            # Average the per-sentence statistics across all DDP processes.
            current_mean = ddp_all_reduce(torch.mean(torch.stack(current_means), dim=0), torch.distributed.ReduceOp.AVG)
            current_std = ddp_all_reduce(torch.mean(torch.stack(current_stds), dim=0), torch.distributed.ReduceOp.AVG)

            if self.norm_type == "batch": out = (x - current_mean.data) / (current_std.data)

            if self.norm_type == "global":
                if self.training:
                    if self.count == 0:
                        self.glob_mean = current_mean
                        self.glob_std = current_std
                    elif epoch is None or epoch < self.update_until_epoch:
                        # Running average; frozen once epoch reaches update_until_epoch.
                        self.weight = (1 / (self.count + 1)) if self.avg_factor is None else self.avg_factor
                        self.glob_mean = (1 - self.weight) * self.glob_mean.to(current_mean) + self.weight * current_mean
                        self.glob_std = (1 - self.weight) * self.glob_std.to(current_std) + self.weight * current_std

                    # NOTE(review): detach() results are discarded (not in-place).
                    self.glob_mean.detach()
                    self.glob_std.detach()
                    self.count = self.count + 1

                out = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))

        return out

    def _compute_current_stats(self, x):
        """Return (mean, std) over dim 0, with std floored at self.eps."""
        current_std = torch.std(x, dim=0).detach().data if self.std_norm else torch.tensor([1.0], device=x.device)
        return torch.mean(x, dim=0).detach().data if self.mean_norm else torch.tensor([0.0], device=x.device), torch.max(current_std, self.eps * torch.ones_like(current_std))

    def _statistics_dict(self):
        """Collect all running statistics into a plain dict for checkpointing."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Restore running statistics from a dict made by _statistics_dict."""
        self.count = state["count"]

        # NOTE(review): both branches are identical; the isinstance check is
        # effectively a no-op, presumably kept for backward compatibility.
        if isinstance(state["glob_mean"], int):
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]
        else:
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]

        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]

        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk]

        self.spk_dict_count = state["spk_dict_count"]
        return state

    def to(self, device):
        """Move the module and all stored statistics tensors to ``device``."""
        self = super(InputNormalization, self).to(device)
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)

        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)

        return self

    @mark_as_saver
    def _save(self, path):
        """Checkpoint hook: persist the running statistics to ``path``."""
        torch.save(self._statistics_dict(), path)

    @mark_as_transfer
    @mark_as_loader
    def _load(self, path, end_of_epoch=False):
        """Checkpoint hook: restore the running statistics from ``path``."""
        del end_of_epoch  # unused; required by the loader-hook signature
        stats = torch.load(path, map_location="cpu")
        self._load_statistics_dict(stats)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main/library/speaker_diarization/parameter_transfer.py DELETED
@@ -1,120 +0,0 @@
1
- import os
2
- import sys
3
- import inspect
4
-
5
- sys.path.append(os.getcwd())
6
-
7
- from main.library.speaker_diarization.speechbrain import fetch, run_on_main
8
- from main.library.speaker_diarization.features import DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS
9
-
10
-
11
def get_default_hook(obj, default_hooks):
    """Return the hook registered for the closest class in obj's MRO, or None."""
    return next((default_hooks[cls] for cls in inspect.getmro(type(obj)) if cls in default_hooks), None)
16
-
17
class Pretrainer:
    """Collects pretrained parameter files and loads them into objects.

    ``loadables`` maps names to objects whose parameters should be loaded.
    ``paths`` may override where each file is fetched from; ``custom_hooks``
    override how each object is loaded; ``conditions`` can disable individual
    loadables (a bool, or a zero-argument callable evaluated lazily).
    """

    def __init__(self, loadables=None, paths=None, custom_hooks=None, conditions=None):
        self.loadables = {}
        if loadables is not None:
            self.add_loadables(loadables)

        self.paths = {}
        if paths is not None:
            self.add_paths(paths)

        self.custom_hooks = {}
        if custom_hooks is not None:
            self.add_custom_hooks(custom_hooks)

        self.conditions = {}
        if conditions is not None:
            self.add_conditions(conditions)

        self.is_local = []  # names whose files have been fetched to local paths

    def add_loadables(self, loadables):
        """Register additional name -> object mappings to load."""
        self.loadables.update(loadables)

    def add_paths(self, paths):
        """Register additional name -> path overrides."""
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks):
        """Register additional name -> load-hook overrides."""
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions):
        """Register additional name -> condition (bool or callable)."""
        self.conditions.update(conditions)

    @staticmethod
    def split_path(path):
        """Split ``path`` into a ``(source, filename)`` tuple.

        "a/b/c.ckpt" -> ("a/b", "c.ckpt"); a bare filename gets "./" as its
        source. Always returns a tuple (the previous implementation returned
        a list in one branch and a tuple in the other).
        """
        if "/" in path:
            source, filename = path.rsplit("/", maxsplit=1)
            return source, filename
        return "./", path

    def collect_files(self, default_source=None):
        """Fetch every enabled loadable's parameter file; return name -> path.

        Raises:
            ValueError: a loadable has neither an explicit path nor a
                ``default_source`` to fetch from.
        """
        loadable_paths = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            save_filename = name + ".ckpt"

            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                filename = save_filename
                source = default_source
            else:
                raise ValueError(f"No path specified for loadable '{name}' and no default_source provided.")

            fetch_kwargs = {"filename": filename, "source": source}
            path = None

            def run_fetch(**kwargs):
                # Capture the resolved path from fetch via the closure.
                nonlocal path
                path = fetch(**kwargs)

            # NOTE(review): run_fetch is passed as both the main and post
            # function — presumably so every rank resolves the same local
            # path after the main process downloads; confirm vs run_on_main.
            run_on_main(run_fetch, kwargs=fetch_kwargs, post_func=run_fetch, post_kwargs=fetch_kwargs)

            loadable_paths[name] = path
            self.paths[name] = str(path)
            self.is_local.append(name)

        return loadable_paths

    def is_loadable(self, name):
        """Return True when no registered condition disables ``name``."""
        if name not in self.conditions:
            return True
        condition = self.conditions[name]
        return condition() if callable(condition) else bool(condition)

    def load_collected(self):
        """Load every enabled loadable from its previously collected file.

        Raises:
            ValueError: ``collect_files`` was not run for an enabled loadable.
        """
        paramfiles = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue

            if name in self.is_local:
                paramfiles[name] = self.paths[name]
            else:
                raise ValueError(f"Loadable '{name}' has no local file; call collect_files() first.")

        self._call_load_hooks(paramfiles)

    def _call_load_hooks(self, paramfiles):
        """Dispatch each loadable to its custom, transfer, or default load hook."""
        for name, obj in self.loadables.items():
            if not self.is_loadable(name):
                continue
            loadpath = paramfiles[name]

            # Explicit per-name hooks take precedence.
            if name in self.custom_hooks:
                self.custom_hooks[name](obj, loadpath)
                continue

            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
            if default_hook is not None:
                default_hook(obj, loadpath)
                continue

            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
            if default_hook is not None:
                # Load hooks additionally take an end_of_epoch flag.
                end_of_epoch = False
                default_hook(obj, loadpath, end_of_epoch)
                continue

            raise RuntimeError(f"No loading hook found for loadable '{name}' of type {type(obj).__name__}.")