|
|
import multiprocessing as mp |
|
|
import os |
|
|
import traceback |
|
|
from concurrent.futures import ProcessPoolExecutor |
|
|
from typing import * |
|
|
|
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from fairseq import checkpoint_utils |
|
|
from tqdm import tqdm |
|
|
|
|
|
# Repository root: four directory levels up from this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# Location of pretrained model checkpoints.
MODELS_DIR = os.path.join(ROOT_DIR, "models")

# Registry of known speech embedders, keyed by user-facing name.
# Each value is (checkpoint filename, embedder identifier, load source).
EMBEDDINGS_LIST = {
    "hubert-base-japanese": (
        "rinna_hubert_base_jp.pt",
        "hubert-base-japanese",
        "local",
    ),
    "contentvec": ("checkpoint_best_legacy_500.pt", "contentvec", "local"),
}


def get_embedder(embedder_name: str) -> Optional[Tuple[str, str, str]]:
    """Return the (filename, identifier, load_source) entry for *embedder_name*.

    Returns None when the name is not registered in EMBEDDINGS_LIST.
    """
    # dict.get already yields None for missing keys; the explicit
    # membership test and second lookup were redundant.
    return EMBEDDINGS_LIST.get(embedder_name)
|
|
|
|
|
|
|
|
def load_embedder(embedder_path: str, device):
    """Load a fairseq embedder checkpoint and move it to *device*.

    Returns a (model, cfg) pair where cfg is the checkpoint's task/model
    config. Off-CPU the model is converted to half precision; on CPU it
    stays in float32. The model is put in eval mode.

    Raises:
        Exception: re-raises whatever the checkpoint loader threw, after
        printing the error and traceback.
    """
    try:
        models, cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
            [embedder_path],
            suffix="",
        )
        embedder_model = models[0]
        embedder_model = embedder_model.to(device)
        # fp16 is only safe off-CPU; CPU kernels expect float32.
        if device != "cpu":
            embedder_model = embedder_model.half()
        else:
            embedder_model = embedder_model.float()
        embedder_model.eval()
    except Exception as e:
        print(f"Error: {e} {embedder_path}")
        traceback.print_exc()
        # Previously execution fell through to the return statement with
        # `embedder_model`/`cfg` unbound, so the caller saw an
        # UnboundLocalError that masked the real failure. Re-raise instead.
        raise

    return embedder_model, cfg
|
|
|
|
|
|
|
|
|
|
|
def readwave(wav_path, normalize=False):
    """Read a 16 kHz wav file and return it as a (1, n_samples) float tensor.

    Stereo input is downmixed to mono by averaging channels. When
    *normalize* is True the waveform is layer-normalized over its length.

    Raises:
        ValueError: if the file is not sampled at 16000 Hz.
    """
    wav, sr = sf.read(wav_path)
    # An assert would be silently stripped under `python -O`; validate
    # the sample rate explicitly instead.
    if sr != 16000:
        raise ValueError(f"expected 16000 Hz audio, got {sr} Hz: {wav_path}")
    feats = torch.from_numpy(wav).float()
    if feats.dim() == 2:  # stereo -> mono by channel average
        feats = feats.mean(-1)
    assert feats.dim() == 1, feats.dim()
    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    feats = feats.view(1, -1)
    return feats
|
|
|
|
|
|
|
|
def processor(
    todo: List[str],
    device: torch.device,
    embedder_path: str,
    embedder_load_from: str,
    # Feature dimensionality of the embedder output; 768 selects the
    # hidden-state path below. (Annotation was `bool`, but the value is
    # compared against 768, so it is an int.)
    embedding_channel: int,
    embedding_output_layer: int,
    wav_dir: str,
    out_dir: str,
    process_id: int,
):
    """Extract embedder features for each wav in *todo* and save them as .npy.

    For every relative path in *todo* ending in ".wav", reads the 16 kHz
    waveform under *wav_dir*, runs it through the embedder loaded from
    *embedder_path*, and writes the feature array to the mirrored path
    under *out_dir* with a .npy extension. Outputs that already exist are
    skipped; per-file errors are printed and processing continues.

    Returns an error string when *embedder_load_from* is "local" and the
    checkpoint file is missing; otherwise returns None.
    """
    # fp16 requires a CUDA device with compute capability >= 5.3.
    half_support = (
        device.type == "cuda" and torch.cuda.get_device_capability(device)[0] >= 5.3
    )
    is_feats_dim_768 = embedding_channel == 768

    if embedder_load_from == "local" and not os.path.exists(embedder_path):
        return f"Embedder not found: {embedder_path}"

    model, cfg = load_embedder(embedder_path, device)

    # position=1+process_id keeps each worker's tqdm bar on its own row.
    for file in tqdm(todo, position=1 + process_id):
        try:
            if file.endswith(".wav"):
                wav_filepath = os.path.join(wav_dir, file)
                # NOTE(review): this replaces every "wav" substring, not only
                # the extension — a stem like "wavetable.wav" would also be
                # renamed; confirm filenames cannot contain "wav" elsewhere.
                out_filepath = os.path.join(out_dir, file.replace("wav", "npy"))

                # Skip work already done by a previous run.
                if os.path.exists(out_filepath):
                    continue

                os.makedirs(os.path.dirname(out_filepath), exist_ok=True)

                # Normalize only when the checkpoint's task config asks for it.
                is_normalize = False if cfg is None else cfg.task.normalize
                feats = readwave(wav_filepath, normalize=is_normalize)
                padding_mask = torch.BoolTensor(feats.shape).fill_(False)
                if isinstance(model, tuple):
                    # Presumably a (feature_extractor, model) pair in the
                    # transformers style — confirm against load_embedder and
                    # its callers; this branch never triggers for the fairseq
                    # single-model load above.
                    feats = model[0](
                        feats.squeeze(0).squeeze(0).to(device),
                        return_tensors="pt",
                        sampling_rate=16000,
                    )
                    if half_support:
                        feats = feats.input_values.to(device).half()
                    else:
                        feats = feats.input_values.to(device).float()

                    with torch.no_grad():
                        if half_support:
                            if is_feats_dim_768:
                                feats = model[1](feats).last_hidden_state
                            else:
                                feats = model[1](feats).extract_features
                        else:
                            # Force float32 weights to match float32 input.
                            if is_feats_dim_768:
                                feats = model[1].float()(feats).last_hidden_state
                            else:
                                feats = model[1].float()(feats).extract_features
                else:
                    # fairseq-style model: call extract_features directly.
                    inputs = {
                        "source": feats.half().to(device)
                        if half_support
                        else feats.to(device),
                        "padding_mask": padding_mask.to(device),
                        "output_layer": embedding_output_layer,
                    }

                    if not half_support:
                        model = model.float()
                        inputs["source"] = inputs["source"].float()

                    with torch.no_grad():
                        logits = model.extract_features(**inputs)
                        if is_feats_dim_768:
                            feats = logits[0]
                        else:
                            # Project down when a smaller feature dim is wanted.
                            feats = model.final_proj(logits[0])

                feats = feats.squeeze(0).float().cpu().numpy()
                # Refuse to save features containing NaNs.
                if np.isnan(feats).sum() == 0:
                    np.save(out_filepath, feats, allow_pickle=False)
                else:
                    print(f"{file} contains nan")
        except Exception as e:
            print(f"Error: {e} {file}")
            traceback.print_exc()
|
|
|
|
|
|
|
|
def run(
    training_dir: str,
    embedder_path: str,
    embedder_load_from: str,
    embedding_channel: int,
    embedding_output_layer: int,
    gpu_ids: List[int],
    device: Optional[Union[torch.device, str]] = None,
):
    """Extract embedder features for a training directory.

    Reads wavs from ``<training_dir>/1_16k_wavs`` and writes features to
    ``<training_dir>/3_feature256``. When *device* is given, processing
    runs in this process on that device (MPS falls back to CPU); otherwise
    the file list is split round-robin across one spawned worker process
    per entry in *gpu_ids*. Returns early (None) when the output directory
    already exists or a GPU id fails the availability check.
    """
    wav_dir = os.path.join(training_dir, "1_16k_wavs")
    out_dir = os.path.join(training_dir, "3_feature256")

    num_gpus = len(gpu_ids)

    # NOTE(review): this compares each id against the *length* of gpu_ids,
    # not against torch.cuda.device_count(); it looks like it should check
    # the actual number of CUDA devices — confirm intent before changing
    # (using device_count() here would break CPU-only runs).
    for gpu_id in gpu_ids:
        if num_gpus < gpu_id + 1:
            print(f"GPU {gpu_id} is not available")
            return

    # Output directory already present: assume extraction is done.
    if os.path.exists(out_dir):
        return

    os.makedirs(out_dir, exist_ok=True)

    # "<speaker_dir>/<file>" relative paths, sorted for determinism.
    # (renamed loop vars: `dir` shadowed the builtin.)
    todo = [
        os.path.join(subdir, fname)
        for subdir in sorted(os.listdir(wav_dir))
        if os.path.isdir(os.path.join(wav_dir, subdir))
        for fname in sorted(os.listdir(os.path.join(wav_dir, subdir)))
    ]

    if device is not None:
        if isinstance(device, str):  # was `type(device) == str`
            device = torch.device(device)
        if device.type == "mps":
            # The embedder does not run on MPS; fall back to CPU.
            device = torch.device("cpu")
        processor(
            todo,
            device,
            embedder_path,
            embedder_load_from,
            embedding_channel,
            embedding_output_layer,
            wav_dir,
            out_dir,
            process_id=0,
        )
    else:
        # Fan the file list out round-robin, one worker process per GPU.
        with ProcessPoolExecutor(mp_context=mp.get_context("spawn")) as executor:
            futures = [
                executor.submit(
                    processor,
                    todo[i::num_gpus],
                    torch.device(f"cuda:{gpu_id}"),
                    embedder_path,
                    embedder_load_from,
                    embedding_channel,
                    embedding_output_layer,
                    wav_dir,
                    out_dir,
                    process_id=i,
                )
                for i, gpu_id in enumerate(gpu_ids)
            ]
            # Previously the futures were discarded, so worker exceptions
            # vanished silently; surface them without crashing the parent.
            for future in futures:
                try:
                    future.result()
                except Exception:
                    traceback.print_exc()
|