# rvc_api/modules/tabs/training.py
# (repository page header residue preserved as comments:
#  aryo100 — "first commit", revision b5a064f)
import math
import os
import shutil
from multiprocessing import cpu_count
import gradio as gr
from lib.rvc.preprocessing import extract_f0, extract_feature, split
from lib.rvc.train import create_dataset_meta, glob_dataset, train_index, train_model
from modules import models, utils
from modules.shared import MODELS_DIR, device, half_support
from modules.ui import Tab
# Maps the sample-rate labels shown in the UI to sample rates in Hz.
SR_DICT = {
    "32k": 32000,
    "40k": 40000,
    "48k": 48000,
}
class Training(Tab):
    """Gradio tab that drives RVC voice-conversion model training.

    Exposes two button-driven pipelines:

    - "Train Index"  -> ``train_index_only``: preprocess the dataset,
      extract f0/features, and build only the retrieval index.
    - "Train"        -> ``train_all``: the full pipeline — preprocess,
      extract f0/features, train the model, then build the index.

    Both callbacks are generators; each ``yield`` streams a status string
    into the ``status`` textbox.
    """

    def title(self):
        # Label shown on the tab.
        return "Training"

    def sort(self):
        # Ordering key among tabs (lower values appear first).
        return 2

    def ui(self, outlet):
        """Build the tab's Gradio layout and wire the training callbacks."""

        def train_index_only(
            model_name,
            target_sr,
            f0,
            dataset_glob,
            recursive,
            multiple_speakers,
            speaker_id,
            gpu_id,
            num_cpu_process,
            norm_audio_when_preprocess,
            pitch_extraction_algo,
            run_train_index,
            reduce_index_size,
            maximum_index_size,
            embedder_name,
            embedding_channels,
            embedding_output_layer,
            ignore_cache,
        ):
            """Preprocess the dataset and train only the retrieval index.

            Generator: yields human-readable progress strings for the UI.
            Raises a plain ``Exception`` when the dataset glob matches no
            audio files.
            """
            # gr.Number delivers floats; downstream code expects an int.
            maximum_index_size = int(maximum_index_size)
            # Radio widgets deliver "Yes"/"No" strings; normalize to bools.
            f0 = f0 == "Yes"
            norm_audio_when_preprocess = norm_audio_when_preprocess == "Yes"
            run_train_index = run_train_index == "Yes"
            reduce_index_size = reduce_index_size == "Yes"
            training_dir = os.path.join(MODELS_DIR, "training", "models", model_name)
            # gpu_id is a comma-separated string such as "0, 1".
            gpu_ids = [int(x.strip()) for x in gpu_id.split(",")] if gpu_id else []
            yield f"Training directory: {training_dir}"
            if os.path.exists(training_dir) and ignore_cache:
                # Throw away previously preprocessed artifacts on request.
                shutil.rmtree(training_dir)
            os.makedirs(training_dir, exist_ok=True)
            # NOTE(review): unlike train_all below, this call does not pass
            # training_dir= to glob_dataset — confirm whether that difference
            # is intentional.
            datasets = glob_dataset(
                dataset_glob,
                speaker_id,
                multiple_speakers=multiple_speakers,
                recursive=recursive,
            )
            if len(datasets) == 0:
                raise Exception("No audio files found")
            yield "Preprocessing..."
            split.preprocess_audio(
                datasets,
                SR_DICT[target_sr],
                num_cpu_process,
                training_dir,
                norm_audio_when_preprocess,
                # Reference "mute" clip matching the target sample rate.
                os.path.join(
                    MODELS_DIR,
                    "training",
                    "mute",
                    "0_gt_wavs",
                    f"mute{target_sr}.wav",
                ),
            )
            if f0:
                yield "Extracting f0..."
                extract_f0.run(training_dir, num_cpu_process, pitch_extraction_algo)
            yield "Extracting features..."
            embedder_filepath, _, embedder_load_from = models.get_embedder(
                embedder_name
            )
            if embedder_load_from == "local":
                # Local embedders are resolved relative to the models dir.
                embedder_filepath = os.path.join(
                    MODELS_DIR, "embeddings", embedder_filepath
                )
            # NOTE(review): train_all passes an extra trailing device argument
            # to extract_feature.run; it is omitted here — verify the default.
            extract_feature.run(
                training_dir,
                embedder_filepath,
                embedder_load_from,
                int(embedding_channels),
                int(embedding_output_layer),
                gpu_ids,
            )
            out_dir = os.path.join(MODELS_DIR, "checkpoints")
            yield "Training index..."
            if run_train_index:
                if not reduce_index_size:
                    # None disables the k-means size reduction in train_index.
                    maximum_index_size = None
                train_index(
                    training_dir,
                    model_name,
                    out_dir,
                    int(embedding_channels),
                    num_cpu_process,
                    maximum_index_size,
                )
            # NOTE(review): train_all yields "Training completed" — the two
            # final messages are inconsistent.
            yield "Training complete"

        def train_all(
            model_name,
            version,
            sampling_rate_str,
            f0,
            dataset_glob,
            recursive,
            multiple_speakers,
            speaker_id,
            gpu_id,
            num_cpu_process,
            norm_audio_when_preprocess,
            pitch_extraction_algo,
            batch_size,
            augment,
            augment_from_pretrain,
            augment_path,
            speaker_info_path,
            cache_batch,
            num_epochs,
            save_every_epoch,
            save_wav_with_checkpoint,
            fp16,
            save_only_last,
            pre_trained_bottom_model_g,
            pre_trained_bottom_model_d,
            run_train_index,
            reduce_index_size,
            maximum_index_size,
            embedder_name,
            embedding_channels,
            embedding_output_layer,
            ignore_cache,
        ):
            """Run the full pipeline: preprocess, extract, train, index.

            Generator: yields human-readable progress strings for the UI.
            Raises a plain ``Exception`` when the dataset glob matches no
            audio files.
            """
            # gr.Number delivers floats; training code expects ints.
            batch_size = int(batch_size)
            num_epochs = int(num_epochs)
            maximum_index_size = int(maximum_index_size)
            # Radio widgets deliver "Yes"/"No" strings; normalize to bools.
            f0 = f0 == "Yes"
            norm_audio_when_preprocess = norm_audio_when_preprocess == "Yes"
            run_train_index = run_train_index == "Yes"
            reduce_index_size = reduce_index_size == "Yes"
            training_dir = os.path.join(MODELS_DIR, "training", "models", model_name)
            gpu_ids = [int(x.strip()) for x in gpu_id.split(",")] if gpu_id else []
            if os.path.exists(training_dir) and ignore_cache:
                # Discard the cached preprocessing output on request.
                shutil.rmtree(training_dir)
            os.makedirs(training_dir, exist_ok=True)
            yield f"Training directory: {training_dir}"
            datasets = glob_dataset(
                dataset_glob,
                speaker_id,
                multiple_speakers=multiple_speakers,
                recursive=recursive,
                training_dir=training_dir,
            )
            if len(datasets) == 0:
                raise Exception("No audio files found")
            yield "Preprocessing..."
            split.preprocess_audio(
                datasets,
                SR_DICT[sampling_rate_str],
                num_cpu_process,
                training_dir,
                norm_audio_when_preprocess,
                # Reference "mute" clip matching the target sample rate.
                os.path.join(
                    MODELS_DIR,
                    "training",
                    "mute",
                    "0_gt_wavs",
                    f"mute{sampling_rate_str}.wav",
                ),
            )
            if f0:
                yield "Extracting f0..."
                extract_f0.run(training_dir, num_cpu_process, pitch_extraction_algo)
            yield "Extracting features..."
            embedder_filepath, _, embedder_load_from = models.get_embedder(
                embedder_name
            )
            if embedder_load_from == "local":
                # Local embedders are resolved relative to the models dir.
                embedder_filepath = os.path.join(
                    MODELS_DIR, "embeddings", embedder_filepath
                )
            extract_feature.run(
                training_dir,
                embedder_filepath,
                embedder_load_from,
                int(embedding_channels),
                int(embedding_output_layer),
                gpu_ids,
                # Multi-GPU runs pick devices per worker; single-GPU uses the
                # shared default device.
                None if len(gpu_ids) > 1 else device,
            )
            # Write the dataset metadata file consumed by the trainer.
            create_dataset_meta(training_dir, f0)
            yield "Training model..."
            print(f"train_all: emb_name: {embedder_name}")
            config = utils.load_config(
                version, training_dir, sampling_rate_str, embedding_channels, fp16
            )
            out_dir = os.path.join(MODELS_DIR, "checkpoints")
            if not augment_from_pretrain:
                # Augmentation-from-pretrain paths are only meaningful when
                # that option is enabled; clear them otherwise.
                augment_path = None
                speaker_info_path = None
            train_model(
                gpu_ids,
                config,
                training_dir,
                model_name,
                out_dir,
                sampling_rate_str,
                f0,
                batch_size,
                augment,
                augment_path,
                speaker_info_path,
                cache_batch,
                num_epochs,
                save_every_epoch,
                save_wav_with_checkpoint,
                pre_trained_bottom_model_g,
                pre_trained_bottom_model_d,
                embedder_name,
                int(embedding_output_layer),
                save_only_last,
                None if len(gpu_ids) > 1 else device,
            )
            yield "Training index..."
            if run_train_index:
                if not reduce_index_size:
                    # None disables the k-means size reduction in train_index.
                    maximum_index_size = None
                train_index(
                    training_dir,
                    model_name,
                    out_dir,
                    int(embedding_channels),
                    num_cpu_process,
                    maximum_index_size,
                )
            yield "Training completed"

        # ------------------------------------------------------------------
        # Layout
        # ------------------------------------------------------------------
        with gr.Group():
            with gr.Box():
                with gr.Column():
                    # Model identity and dataset selection.
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Textbox(label="Model Name")
                            ignore_cache = gr.Checkbox(label="Ignore cache")
                        with gr.Column():
                            dataset_glob = gr.Textbox(
                                label="Dataset glob", placeholder="data/**/*.wav"
                            )
                            recursive = gr.Checkbox(label="Recursive", value=True)
                            multiple_speakers = gr.Checkbox(
                                label="Multiple speakers", value=False
                            )
                            speaker_id = gr.Slider(
                                maximum=4,
                                minimum=0,
                                value=0,
                                step=1,
                                label="Speaker ID",
                            )
                    # Core model options.
                    with gr.Row(equal_height=False):
                        version = gr.Radio(
                            choices=["v1", "v2"],
                            value="v2",
                            label="Model version",
                        )
                        target_sr = gr.Radio(
                            choices=["32k", "40k", "48k"],
                            value="40k",
                            label="Target sampling rate",
                        )
                        f0 = gr.Radio(
                            choices=["Yes", "No"],
                            value="Yes",
                            label="f0 Model",
                        )
                    # Embedder configuration.
                    with gr.Row(equal_height=False):
                        embedding_name = gr.Radio(
                            choices=list(models.EMBEDDINGS_LIST.keys()),
                            value="contentvec",
                            label="Using phone embedder",
                        )
                        embedding_channels = gr.Radio(
                            choices=["256", "768"],
                            value="768",
                            label="Embedding channels",
                        )
                        embedding_output_layer = gr.Radio(
                            choices=["9", "12"],
                            value="12",
                            label="Embedding output layer",
                        )
                    # Hardware and preprocessing settings.
                    with gr.Row(equal_height=False):
                        gpu_id = gr.Textbox(
                            label="GPU ID",
                            # Default to all detected GPUs, comma-separated.
                            value=", ".join([f"{x.index}" for x in utils.get_gpus()]),
                        )
                        num_cpu_process = gr.Slider(
                            minimum=0,
                            maximum=cpu_count(),
                            step=1,
                            # Default to half the logical CPUs, rounded up.
                            value=math.ceil(cpu_count() / 2),
                            label="Number of CPU processes",
                        )
                        norm_audio_when_preprocess = gr.Radio(
                            choices=["Yes", "No"],
                            value="Yes",
                            label="Normalize audio volume when preprocess",
                        )
                        pitch_extraction_algo = gr.Radio(
                            choices=["dio", "harvest", "mangio-crepe", "crepe"],
                            value="crepe",
                            label="Pitch extraction algorithm",
                        )
                    # Training hyperparameters and checkpointing.
                    with gr.Row(equal_height=False):
                        batch_size = gr.Number(value=4, label="Batch size")
                        num_epochs = gr.Number(
                            value=30,
                            label="Number of epochs",
                        )
                        save_every_epoch = gr.Slider(
                            minimum=0,
                            maximum=100,
                            value=10,
                            step=1,
                            label="Save every epoch",
                        )
                        save_wav_with_checkpoint = gr.Checkbox(
                            label="save_wav_with_checkpoint", value=False
                        )
                        cache_batch = gr.Checkbox(label="Cache batch", value=True)
                        # NOTE(review): "disabled" may need to be
                        # "interactive=" depending on the pinned Gradio
                        # version — confirm.
                        fp16 = gr.Checkbox(
                            label="FP16", value=half_support, disabled=not half_support
                        )
                        save_only_last = gr.Checkbox(
                            label="Save only the latest G and D files", value=False
                        )
                    # Augmentation options.
                    with gr.Row(equal_height=False):
                        augment = gr.Checkbox(label="Augment", value=False)
                        augment_from_pretrain = gr.Checkbox(
                            label="Augment From Pretrain", value=False
                        )
                        augment_path = gr.Textbox(
                            label="Pre trained generator path (pth)",
                            value="file is not prepared",
                        )
                        speaker_info_path = gr.Textbox(
                            label="speaker info path (npy)",
                            value="file is not prepared",
                        )
                    # Pretrained bottom models (defaults match the v2/40k UI
                    # defaults above).
                    with gr.Row(equal_height=False):
                        pre_trained_generator = gr.Textbox(
                            label="Pre trained generator path",
                            value=os.path.join(
                                MODELS_DIR, "pretrained", "v2", "f0G40k.pth"
                            ),
                        )
                        pre_trained_discriminator = gr.Textbox(
                            label="Pre trained discriminator path",
                            value=os.path.join(
                                MODELS_DIR, "pretrained", "v2", "f0D40k.pth"
                            ),
                        )
                    # Index-building options.
                    with gr.Row(equal_height=False):
                        run_train_index = gr.Radio(
                            choices=["Yes", "No"],
                            value="Yes",
                            label="Train Index",
                        )
                        reduce_index_size = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="Reduce index size with kmeans",
                        )
                        maximum_index_size = gr.Number(
                            value=10000, label="maximum index size"
                        )
                    # Status output and action buttons.
                    with gr.Row(equal_height=False):
                        status = gr.Textbox(value="", label="Status")
                    with gr.Row(equal_height=False):
                        train_index_button = gr.Button("Train Index", variant="primary")
                        train_all_button = gr.Button("Train", variant="primary")

        # Wire the buttons to the generator callbacks; yielded strings stream
        # into the status textbox. Input order must match each callback's
        # parameter order exactly.
        train_index_button.click(
            train_index_only,
            inputs=[
                model_name,
                target_sr,
                f0,
                dataset_glob,
                recursive,
                multiple_speakers,
                speaker_id,
                gpu_id,
                num_cpu_process,
                norm_audio_when_preprocess,
                pitch_extraction_algo,
                run_train_index,
                reduce_index_size,
                maximum_index_size,
                embedding_name,
                embedding_channels,
                embedding_output_layer,
                ignore_cache,
            ],
            outputs=[status],
        )
        train_all_button.click(
            train_all,
            inputs=[
                model_name,
                version,
                target_sr,
                f0,
                dataset_glob,
                recursive,
                multiple_speakers,
                speaker_id,
                gpu_id,
                num_cpu_process,
                norm_audio_when_preprocess,
                pitch_extraction_algo,
                batch_size,
                augment,
                augment_from_pretrain,
                augment_path,
                speaker_info_path,
                cache_batch,
                num_epochs,
                save_every_epoch,
                save_wav_with_checkpoint,
                fp16,
                save_only_last,
                pre_trained_generator,
                pre_trained_discriminator,
                run_train_index,
                reduce_index_size,
                maximum_index_size,
                embedding_name,
                embedding_channels,
                embedding_output_layer,
                ignore_cache,
            ],
            outputs=[status],
        )