File size: 8,379 Bytes
948d8c3 585a5b0 9ba0d10 357ef6e 948d8c3 9ba0d10 948d8c3 e744b4f 948d8c3 585a5b0 9ba0d10 585a5b0 948d8c3 1b0482d 948d8c3 503beb6 948d8c3 503beb6 948d8c3 0cc5fb5 948d8c3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 | import os
import sys
import json
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import soundfile as sf
import librosa
from fairseq.tasks.audio_pretraining import AudioPretrainingTask
from fairseq import checkpoint_utils
from transformers import HubertModel, Wav2Vec2Model
# Environment settings for Apple-Silicon (MPS) runs:
#   PYTORCH_ENABLE_MPS_FALLBACK=1          -> ops lacking an MPS kernel fall back to CPU.
#   PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0   -> lift the MPS allocator's upper memory limit.
# NOTE(review): torch is already imported above this point; confirm that
# PYTORCH_ENABLE_MPS_FALLBACK is still honored when set after `import torch`.
os.environ.update(
    PYTORCH_ENABLE_MPS_FALLBACK="1",
    PYTORCH_MPS_HIGH_WATERMARK_RATIO="0.0",
)
# Configuration class for device and precision management
class Config:
    """Choose the compute device and whether models should run in half precision.

    Construction also loads the sample-rate JSON files listed in
    ``version_config_paths`` from the ``configs/`` directory into
    ``self.json_config``.

    Attributes:
        device: final device string ("cuda[:N]", "mps" or "cpu").
        is_half: True when half precision should be used.
        version_config_paths: JSON file names loaded by ``load_config_json``.
        json_config: mapping of file name -> parsed JSON content.
    """

    def __init__(self, device):
        # Without CUDA, start from CPU; device_config() may still pick MPS.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.is_half = self.device != "cpu"
        self.version_config_paths = [
            f"{k}.json" for k in ["32k", "40k", "48k", "48k_v2", "32k_v2"]
        ]
        self.json_config = self.load_config_json()
        self.device_config()

    def load_config_json(self):
        """Load every file in ``self.version_config_paths`` from ``configs/``.

        Returns:
            dict mapping each file name (e.g. ``"32k.json"``) to its parsed JSON.

        Raises:
            OSError if a file cannot be opened; json.JSONDecodeError if one is malformed.
        """
        configs = {}
        for config_file in self.version_config_paths:
            config_path = os.path.join("configs", config_file)
            with open(config_path, "r", encoding="utf-8") as f:
                configs[config_file] = json.load(f)
        return configs

    @staticmethod
    def _cuda_index(device):
        """Return N for a ``"cuda:N"`` device string, or None for a bare ``"cuda"``."""
        return torch.device(device).index

    def device_config(self):
        """Finalize ``self.device`` and ``self.is_half`` for the available hardware."""
        if self.device.startswith("cuda"):
            # A bare "cuda" has no ":N" suffix (int("cuda") used to raise
            # ValueError here); use the current CUDA device in that case.
            i_device = self._cuda_index(self.device)
            if i_device is None:
                i_device = torch.cuda.current_device()
            gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
            # NOTE(review): half precision is enabled only for V100 cards with
            # more than 4 GiB of memory; confirm this is the intended policy.
            self.is_half = gpu_mem > 4 and "V100" in torch.cuda.get_device_name(i_device)
        elif torch.backends.mps.is_available():
            self.device = "mps"
            self.is_half = False
        else:
            self.device = "cpu"
            self.is_half = False
# Model-specific definitions
class HubertModelWithFinalProj(HubertModel):
    """HuggingFace ``HubertModel`` extended with a ``final_proj`` linear layer.

    ``final_proj`` maps ``config.hidden_size`` features to
    ``config.classifier_proj_size`` features. Only ``__init__`` is overridden,
    so ``forward`` is the inherited one and does not apply ``final_proj``.
    """

    def __init__(self, config):
        super().__init__(config)
        in_features = config.hidden_size
        out_features = config.classifier_proj_size
        self.final_proj = nn.Linear(in_features, out_features)
def load_hubert_fairseq(model_path, device, is_half):
    """Load a fairseq HuBERT checkpoint as an eval-mode model on *device*.

    The checkpoint's task config is used to build an ``AudioPretrainingTask``,
    which is forced onto the ensemble loader. The model is cast to half
    precision only when *is_half* is true and *device* is neither "mps" nor "cpu".

    Returns:
        {"model": the loaded module, "saved_cfg": the config returned by the loader}
    """
    checkpoint = checkpoint_utils.load_checkpoint_to_cpu(model_path)
    forced_task = AudioPretrainingTask.setup_task(checkpoint["cfg"].task)
    ensemble, loaded_cfg, _unused_task = checkpoint_utils.load_model_ensemble_and_task(
        [model_path], task=forced_task
    )
    hubert = ensemble[0].to(device)
    wants_fp16 = is_half and device not in ("mps", "cpu")
    if wants_fp16:
        hubert = hubert.half()
    hubert.eval()
    return {"model": hubert, "saved_cfg": loaded_cfg}
def load_huggingface_model(model_path, device, is_half, model_class=HubertModelWithFinalProj):
    """Load a HuggingFace model via ``model_class.from_pretrained`` in eval mode.

    The weights are moved to *device* and cast to float16 when *is_half* is true
    and *device* contains "cuda"; otherwise they are cast to float32.

    Returns:
        {"model": the loaded module}
    """
    use_fp16 = is_half and "cuda" in device
    target_dtype = torch.float16 if use_fp16 else torch.float32
    network = model_class.from_pretrained(model_path)
    network = network.to(device).to(target_dtype)
    network.eval()
    return {"model": network}
def hubert_preprocess(feats, saved_cfg):
    """Optionally layer-normalize raw waveform samples before fairseq HuBERT.

    Args:
        feats: float tensor whose last dimension is the sample axis
            (``main`` builds it as shape (1, samples)).
        saved_cfg: fairseq config from ``load_hubert_fairseq``, or None.

    Returns:
        *feats* unchanged when *saved_cfg* is None or ``saved_cfg.task.normalize``
        is falsy; otherwise each waveform normalized over its own samples.
    """
    if saved_cfg is None or not saved_cfg.task.normalize:
        return feats
    with torch.no_grad():
        # Normalize over the last axis only, so every row of a batch is
        # normalized independently. For a (1, samples) input this is identical
        # to the previous whole-tensor normalization.
        return F.layer_norm(feats, feats.shape[-1:])
def hubert_prepare_input(feats, device, version):
    """Build the keyword arguments passed to fairseq ``extract_features``.

    Args:
        feats: waveform tensor (``main`` passes shape (1, samples)).
        device: target device string.
        version: "v1" selects output layer 9; anything else selects layer 12.

    Returns:
        dict with "source" (moved to *device*, cast to half on any device other
        than "mps"/"cpu"), an all-False boolean "padding_mask" shaped like
        *feats*, and "output_layer".
    """
    if version == "v1":
        layer = 9
    else:
        layer = 12
    # NOTE(review): the half cast does not consult Config.is_half, so on CUDA
    # the input may not match a float32 model.
    if device in ("mps", "cpu"):
        source = feats.to(device)
    else:
        source = feats.half().to(device)
    no_padding = torch.zeros(feats.shape, dtype=torch.bool).to(device)
    return {
        "source": source,
        "padding_mask": no_padding,
        "output_layer": layer,
    }
def hubert_extract_features(model, inputs):
    """Run a fairseq HuBERT model and return its features.

    Args:
        model: module exposing ``extract_features(**inputs)``, ``final_proj`` and
            ``parameters()`` (as returned by ``load_hubert_fairseq``).
        inputs: dict with "source", "padding_mask" and "output_layer"
            (as built by ``hubert_prepare_input``). It is not mutated.

    Returns:
        The first element of ``model.extract_features(...)``, passed through
        ``model.final_proj`` when ``inputs["output_layer"] == 9``.
    """
    param = next(model.parameters(), None)
    source = inputs["source"]
    if param is not None and source.is_floating_point() and source.dtype != param.dtype:
        # hubert_prepare_input halves the input on any non-CPU/MPS device,
        # while load_hubert_fairseq halves the model only when is_half is set.
        # Align the input with the weights to avoid a dtype-mismatch error.
        inputs = {**inputs, "source": source.to(param.dtype)}
    with torch.no_grad():
        logits = model.extract_features(**inputs)
        feats = model.final_proj(logits[0]) if inputs["output_layer"] == 9 else logits[0]
    return feats
def general_preprocess(feats, *_ignored):
    """Identity preprocessing step for HuggingFace models; extra arguments are ignored."""
    result = feats
    return result
def general_prepare_input(feats, device, *args):
    """Move *feats* to *device* for a HuggingFace model.

    Extra positional arguments are accepted and ignored. ``main`` calls every
    ``prepare_input`` as ``(feats, device, version)``; with only two parameters
    that call raised TypeError for the contentvec/wav2vec pipelines.
    """
    return feats.to(device)
def general_extract_features(model, inputs):
    """Run a HuggingFace model and return its ``last_hidden_state``.

    Floating-point *inputs* are first cast to the dtype of the model's
    parameters. ``load_huggingface_model`` may cast the model to float16,
    while the waveform tensor built in ``main`` is float32; without the cast
    the forward pass fails with a dtype mismatch.
    """
    param = next(model.parameters(), None)
    if param is not None and inputs.is_floating_point() and inputs.dtype != param.dtype:
        inputs = inputs.to(param.dtype)
    with torch.no_grad():
        feats = model(inputs)["last_hidden_state"]
    return feats
# Model configurations
# Each entry describes one model family used by main():
#   target_sr        - sample rate audio is resampled to before extraction
#   load_model       - (path, device, is_half) -> {"model": ..., ["saved_cfg": ...]}
#   preprocess       - (feats, saved_cfg) -> feats
#   prepare_input    - (feats, device, version) -> model inputs
#   extract_features - (model, inputs) -> feature tensor
model_configs = {
    "hubert": {
        "target_sr": 16000,
        "load_model": load_hubert_fairseq,
        "preprocess": hubert_preprocess,
        "prepare_input": hubert_prepare_input,
        "extract_features": hubert_extract_features,
    },
    "contentvec": {
        "target_sr": 16000,
        # ContentVecModel was referenced here but is not defined in this module,
        # so loading contentvec raised NameError. Use HubertModelWithFinalProj,
        # which is also load_huggingface_model's default class.
        # NOTE(review): confirm the contentvec checkpoint matches that architecture.
        "load_model": lambda path, dev, half: load_huggingface_model(
            path, dev, half, HubertModelWithFinalProj
        ),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
    "wav2vec": {
        "target_sr": 16000,
        "load_model": lambda path, dev, half: load_huggingface_model(
            path, dev, half, Wav2Vec2Model
        ),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
}
# Utility functions
def load_audio(file, target_sr):
    """Read an audio file, downmix it to mono and resample it to *target_sr*.

    Args:
        file: path to the audio file; surrounding whitespace is stripped.
        target_sr: desired sample rate.

    Returns:
        The samples as returned by soundfile / librosa, resampled only when the
        file's native rate differs from *target_sr*.
    """
    samples, native_sr = sf.read(file.strip())
    if samples.ndim > 1:
        # Transposed so channels lead, as passed to librosa.to_mono.
        samples = librosa.to_mono(samples.T)
    if native_sr == target_sr:
        return samples
    return librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
def printt(f, strr):
    """Echo *strr* to stdout, then append it plus a newline to the open file *f* and flush."""
    print(strr)
    f.write("{}\n".format(strr))
    f.flush()
# Main script
def _parse_args(argv):
    """Split the positional command line into named settings.

    Accepted layouts (``argv[0]`` is the script name):

    * 8 entries: ``device n_part i_part exp_dir version model_path model_name``
    * 9 entries: ``device n_part i_part i_gpu exp_dir version model_path model_name``

    Returns:
        dict of settings; "i_gpu" is None for the 8-entry layout.

    Raises:
        SystemExit with a usage message when fewer than 8 entries are given.
    """
    if len(argv) < 8:
        raise SystemExit(
            "Usage: device n_part i_part [i_gpu] exp_dir version model_path model_name"
        )
    # The optional i_gpu argument shifts every later positional by one.
    shift = 1 if len(argv) >= 9 else 0
    return {
        "device": argv[1],
        "n_part": int(argv[2]),
        "i_part": int(argv[3]),
        "i_gpu": argv[4] if shift else None,
        "exp_dir": argv[4 + shift],
        "version": argv[5 + shift],
        "model_path": argv[6 + shift],
        "model_name": argv[7 + shift],
    }


def _resolve_model(model_path, model_name):
    """Map a ``.../Custom`` model path to the bundled checkpoint for *model_name*.

    When the last path component is "Custom" and *model_name* is a known alias,
    returns (bundled file name, pipeline key); otherwise returns the inputs unchanged.
    """
    custom_mappings = {
        "wav2vec": ("wav2vec_small_960h.pt", "wav2vec"),
        "hubert_base": ("hubert_base.pt", "hubert"),
        "contentvec_base": ("contentvec_base.pt", "contentvec"),
        "hubert_base_japanese": ("hubert_base_japanese.pt", "hubert_base_japanese"),
    }
    if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
        return custom_mappings[model_name]
    return model_path, model_name


def main():
    """Extract features for this worker's share of ``<exp_dir>/1_16k_wavs``.

    Every ``n_part``-th entry of the sorted directory listing, starting at
    ``i_part``, is processed. For each ``.wav`` file:
    - features are saved as ``.npy`` under ``<exp_dir>/3_feature<dim>``;
    - files whose output already exists are skipped;
    - files whose features contain NaN are logged, not saved;
    - per-file exceptions are logged without stopping the run.

    Progress is appended to ``<exp_dir>/extract_f0_feature.log``.
    """
    args = _parse_args(sys.argv)
    if args["i_gpu"] is not None:
        # Set before Config(), the first code here that queries CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = args["i_gpu"]
    config = Config(args["device"])
    exp_dir = args["exp_dir"]
    version = args["version"]

    # `with` closes the log on every exit path, including the early return and
    # the sys.exit() calls below (the file used to be leaked there).
    with open(f"{exp_dir}/extract_f0_feature.log", "a+", encoding="utf-8") as log_file:
        printt(log_file, f"Args: {sys.argv}")

        # Resolve model path and name
        model_path, model_name = _resolve_model(args["model_path"], args["model_name"])
        if not os.path.exists(model_path):
            printt(log_file, f"Error: {model_path} does not exist.")
            sys.exit(1)

        # `model_configs.get(name, model_configs[name])` evaluated its default
        # eagerly and raised a bare KeyError for names such as
        # "hubert_base_japanese"; report the problem explicitly instead.
        model_config = model_configs.get(model_name)
        if model_config is None:
            printt(
                log_file,
                f"Error: unsupported model '{model_name}'. "
                f"Supported: {', '.join(model_configs)}",
            )
            sys.exit(1)

        # Load model
        model_dict = model_config["load_model"](model_path, config.device, config.is_half)
        model = model_dict["model"]
        additional_configs = model_dict.get("saved_cfg")
        printt(log_file, f"Loaded model from {model_path} on {config.device}")

        # Setup directories
        feature_dim = 256 if version == "v1" else 768 if model_name != "hubert_large_ll60k" else 1024
        wav_path = f"{exp_dir}/1_16k_wavs"
        out_path = f"{exp_dir}/3_feature{feature_dim}"
        os.makedirs(out_path, exist_ok=True)

        # Process audio files
        todo = sorted(os.listdir(wav_path))[args["i_part"]::args["n_part"]]
        printt(log_file, f"Total files to process: {len(todo)}")
        if not todo:
            printt(log_file, "No files to process.")
            return

        target_sr = model_config["target_sr"]
        for file in todo:
            if not file.endswith(".wav"):
                continue
            try:
                wav_file = f"{wav_path}/{file}"
                # Swap only the trailing extension; str.replace would also
                # rewrite any earlier ".wav" inside the name.
                out_file = f"{out_path}/{file[:-len('.wav')]}.npy"
                if os.path.exists(out_file):
                    continue
                wav = load_audio(wav_file, target_sr)
                # view(1, -1) always yields (1, samples), so no further downmix is needed.
                feats = torch.from_numpy(wav).float().view(1, -1)
                preprocessed_feats = model_config["preprocess"](feats, additional_configs)
                inputs = model_config["prepare_input"](preprocessed_feats, config.device, version)
                feats = model_config["extract_features"](model, inputs)
                # Save features
                feats = feats.squeeze(0).float().cpu().numpy()
                if np.isnan(feats).any():
                    printt(log_file, f"{file} contains NaN values")
                else:
                    np.save(out_file, feats, allow_pickle=False)
                    printt(log_file, f"Processed {file}: {feats.shape}")
            except Exception:
                # Top-level per-file boundary: log and continue with the next file.
                printt(log_file, traceback.format_exc())
        printt(log_file, "Feature extraction completed.")
if __name__ == "__main__":
    # Run the extraction only when executed as a script, not on import.
    main()