# cfm_svc/prepare/preprocess_speaker.py
# Author: Hector Li
# Initial commit for Hugging Face (commit df93d13)
import sys
import os

# Make the project root importable so the `speaker` package resolves when this
# script is run directly. Must run before the `speaker.*` imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import multiprocessing
from argparse import RawTextHelpFormatter
from functools import partial
from multiprocessing.pool import ThreadPool

import numpy as np
import torch
from tqdm import tqdm

from speaker.config import SpeakerEncoderConfig
from speaker.infer import read_json
from speaker.models.lstm import LSTMSpeakerEncoder
from speaker.utils.audio import AudioProcessor
def str2bool(value):
    """Parse a boolean-like CLI argument into a real ``bool``.

    Real booleans pass through unchanged. Text is matched case- and
    whitespace-insensitively against the common truthy spellings
    ("1", "true", "yes", "y", "on") and their falsy counterparts
    ("0", "false", "no", "n", "off").

    Raises:
        argparse.ArgumentTypeError: if the text matches neither set.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    truthy = {"1", "true", "yes", "y", "on"}
    falsy = {"0", "false", "no", "n", "off"}
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value.")
def get_spk_wavs(dataset_path, output_path):
    """Scan *dataset_path* for wav files and pre-create the mirrored
    per-speaker directories under *output_path*.

    Two layouts are handled: ``dataset_path/<speaker>/<file>.wav`` (one
    level of speaker sub-folders) and loose ``dataset_path/<file>.wav``
    entries. Every returned path is prefixed with ``./``.

    Returns:
        list[str]: paths of all discovered ``.wav`` files.
    """
    if not os.path.exists(f"./{output_path}"):
        os.makedirs(f"./{output_path}")
    collected = []
    for entry in os.listdir(dataset_path):
        entry_dir = f"./{dataset_path}/{entry}"
        if os.path.isdir(entry_dir):
            # Mirror the speaker folder so embeddings can be saved later.
            mirror = f"./{output_path}/{entry}"
            if not os.path.exists(mirror):
                os.makedirs(mirror)
            collected.extend(
                f"./{dataset_path}/{entry}/{name}"
                for name in os.listdir(entry_dir)
                if name.endswith(".wav")
            )
        elif entry.endswith(".wav"):
            collected.append(f"./{dataset_path}/{entry}")
    return collected
def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder):
    """Compute the speaker embedding for one wav file and save it to disk.

    The output location mirrors the input path with *dataset_path* swapped
    for *output_path* and the extension changed to ``.spk`` (``np.save``
    appends ``.npy``, so the file on disk ends in ``.spk.npy``).

    Args:
        wav_file: input wav path; starts with *dataset_path*.
        dataset_path: root of the input tree.
        output_path: root of the (pre-created) output tree.
        args: namespace providing the ``use_cuda`` flag.
        speaker_encoder_ap: audio front-end with ``load_wav`` /
            ``melspectrogram`` / ``sample_rate``.
        speaker_encoder: model exposing ``compute_embedding``.
    """
    waveform = speaker_encoder_ap.load_wav(
        wav_file, sr=speaker_encoder_ap.sample_rate
    )
    spec = speaker_encoder_ap.melspectrogram(waveform)
    spec = torch.from_numpy(spec.T)
    if args.use_cuda:
        spec = spec.cuda()
    spec = spec.unsqueeze(0)
    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
    embed = embed.squeeze()
    # Replace only the leading dataset root (count=1); a plain replace()
    # would also rewrite a directory name that repeats the root elsewhere
    # in the path.
    embed_path = wav_file.replace(dataset_path, output_path, 1)
    # Swap only the final extension; replace(".wav", ".spk") would also
    # corrupt a ".wav" occurring inside a directory name.
    embed_path = os.path.splitext(embed_path)[0] + ".spk"
    np.save(embed_path, embed, allow_pickle=False)
def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency):
    """Run :func:`process_wav` over *wav_files* on a thread pool of size
    *concurrency*, showing a tqdm progress bar as files complete."""
    worker = partial(
        process_wav,
        dataset_path=dataset_path,
        output_path=output_path,
        args=args,
        speaker_encoder_ap=speaker_encoder_ap,
        speaker_encoder=speaker_encoder,
    )
    with ThreadPool(concurrency) as pool:
        # Drain the lazy imap iterator; tqdm ticks once per finished file.
        for _ in tqdm(pool.imap(worker, wav_files), total=len(wav_files)):
            pass
if __name__ == "__main__":
    # CLI entry point: write one speaker-embedding file per wav in a dataset.
    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("dataset_path", type=str, help="Path to dataset waves.")
    parser.add_argument(
        "output_path", type=str, help="path for output speaker/speaker_wavs.npy."
    )
    parser.add_argument("--use_cuda", type=str2bool, nargs="?", const=True, default=None, help="flag to set cuda")
    parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
    args = parser.parse_args()

    # Resolve the CUDA flag: unset defaults to hardware availability, and an
    # explicit request is dropped (with a warning) when no device is present.
    if args.use_cuda is None:
        args.use_cuda = torch.cuda.is_available()
    elif args.use_cuda and not torch.cuda.is_available():
        print("CUDA requested but not available. Falling back to CPU for speaker encoder.")
        args.use_cuda = False
    print(f"use_cuda={args.use_cuda}")

    # Pretrained speaker-encoder checkpoint and its JSON configuration.
    args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
    args.config_path = os.path.join("speaker_pretrain", "config.json")
    config_dict = read_json(args.config_path)
    config = SpeakerEncoderConfig(config_dict)
    config.from_dict(config_dict)

    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )
    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)

    # Audio front-end: normalize input level and trim silences before the
    # mel spectrogram is computed.
    speaker_encoder_ap = AudioProcessor(**config.audio)
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True

    wav_files = get_spk_wavs(args.dataset_path, args.output_path)
    # thread_count == 0 means "use every CPU core".
    worker_count = os.cpu_count() if args.thread_count == 0 else args.thread_count
    extract_speaker_embeddings(wav_files, args.dataset_path, args.output_path, args, speaker_encoder_ap, speaker_encoder, worker_count)