Pj12 committed on
Commit
948d8c3
·
verified ·
1 Parent(s): a2a173e

Upload extract_audio_features.py

Browse files
Files changed (1) hide show
  1. extract_audio_features.py +225 -0
extract_audio_features.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import traceback
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import librosa
11
+ from fairseq import checkpoint_utils
12
+ from transformers import HubertModel
13
+
14
# Environment switches for Apple MPS: turn on PyTorch's MPS fallback flag and
# set the MPS high-watermark ratio to 0.0.
# NOTE(review): these are assigned after `import torch` has already executed
# at the top of the file; confirm PyTorch still honors them at this point.
os.environ.update(
    {
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "PYTORCH_MPS_HIGH_WATERMARK_RATIO": "0.0",
    }
)
17
+
18
+ # Configuration class for device and precision management
19
class Config:
    """Pick the compute device and precision, and load the JSON configs.

    Instance fields:
      device: device string actually used ("cuda...", "mps" or "cpu").
      is_half: whether half precision should be used on that device.
      version_config_paths: file names of the JSON configs to load.
      json_config: dict mapping each of those file names to its parsed JSON.
    """

    def __init__(self, device):
        """Resolve `device`, read every config file, then settle precision.

        Args:
            device: requested device string; replaced by "cpu" when CUDA is
                not available.

        Raises:
            OSError: if a file under ``configs/`` cannot be opened.
            json.JSONDecodeError: if a config file is not valid JSON.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        # Provisional value; device_config() below has the final say.
        self.is_half = self.device != "cpu"
        self.version_config_paths = [
            os.path.join("", f"{k}.json")
            for k in ["32k", "40k", "48k", "48k_v2", "40k_v2", "32k_v2"]
        ]
        self.json_config = self.load_config_json()
        self.device_config()

    def load_config_json(self):
        """Return ``{file_name: parsed_json}`` for every configured file.

        Files are read from the ``configs`` directory relative to the current
        working directory.
        """
        configs = {}
        for config_file in self.version_config_paths:
            config_path = os.path.join("configs", config_file)
            # JSON text is UTF-8 by specification; do not depend on the locale.
            with open(config_path, "r", encoding="utf-8") as f:
                configs[config_file] = json.load(f)
        return configs

    def device_config(self):
        """Finalize `self.device` and `self.is_half` for the host hardware.

        * CUDA device: half precision is enabled only when the GPU reports
          more than 4 GiB of memory and its name contains "V100".
        * Otherwise MPS is used when available, else the CPU; both run in
          full precision.
        """
        if self.device.startswith("cuda"):
            # "cuda:1" -> 1. A bare "cuda" carries no index, which used to
            # crash in int(); fall back to the currently selected device.
            _, sep, index_text = self.device.rpartition(":")
            i_device = int(index_text) if sep else torch.cuda.current_device()
            gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
            self.is_half = gpu_mem > 4 and "V100" in torch.cuda.get_device_name(i_device)
        elif torch.backends.mps.is_available():
            self.device = "mps"
            self.is_half = False
        else:
            self.device = "cpu"
            self.is_half = False
48
+
49
+ # Model-specific definitions
50
class HubertModelWithFinalProj(HubertModel):
    """`HubertModel` with an additional linear `final_proj` layer.

    The extra layer maps `config.hidden_size` input features to
    `config.classifier_proj_size` output features.
    """

    def __init__(self, config):
        """Build the base model, then attach the projection layer."""
        super().__init__(config)
        in_features = config.hidden_size
        out_features = config.classifier_proj_size
        self.final_proj = nn.Linear(in_features, out_features)
54
+
55
def load_hubert_fairseq(model_path, device, is_half):
    """Load a fairseq checkpoint and return its first model in eval mode.

    Args:
        model_path: path of the checkpoint handed to fairseq.
        device: device string the model is moved to.
        is_half: request half precision; ignored on "mps" and "cpu".

    Returns:
        ``{"model": model, "saved_cfg": saved_cfg}`` where `saved_cfg` is the
        configuration fairseq returned alongside the checkpoint.
    """
    ensemble, saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        [model_path]
    )
    hubert = ensemble[0].to(device)
    use_half = is_half and device not in ("mps", "cpu")
    if use_half:
        hubert = hubert.half()
    hubert.eval()
    return {"model": hubert, "saved_cfg": saved_cfg}
62
+
63
def load_huggingface_model(model_path, device, is_half, model_class=HubertModelWithFinalProj):
    """Load a model with `model_class.from_pretrained` and put it in eval mode.

    Args:
        model_path: identifier or path passed to `from_pretrained`.
        device: device string the model is moved to.
        is_half: request float16; only honored when "cuda" is in `device`.
        model_class: class providing `from_pretrained`.

    Returns:
        ``{"model": model}``.
    """
    if is_half and "cuda" in device:
        dtype = torch.float16
    else:
        dtype = torch.float32
    loaded = model_class.from_pretrained(model_path)
    loaded = loaded.to(device)
    loaded = loaded.to(dtype)
    loaded.eval()
    return {"model": loaded}
68
+
69
def hubert_preprocess(feats, saved_cfg):
    """Layer-normalize `feats` when the checkpoint config asks for it.

    Args:
        feats: input tensor.
        saved_cfg: object exposing the flag `task.normalize`.

    Returns:
        `feats` unchanged when the flag is falsy, otherwise `feats`
        layer-normalized over its full shape.
    """
    if not saved_cfg.task.normalize:
        return feats
    with torch.no_grad():
        return F.layer_norm(feats, feats.shape)
74
+
75
def hubert_prepare_input(feats, device, version):
    """Build the keyword arguments passed to the model's `extract_features`.

    Args:
        feats: input tensor.
        device: target device string.
        version: "v1" selects output layer 9; anything else selects 12.

    Returns:
        Dict with keys "source", "padding_mask" (all-False boolean tensor of
        the same shape as `feats`) and "output_layer".

    NOTE(review): "source" is cast to half on every device other than "mps"
    and "cpu", regardless of whether the model itself was loaded in half
    precision.
    """
    on_accelerator = device not in ["mps", "cpu"]
    if on_accelerator:
        source = feats.half().to(device)
    else:
        source = feats.to(device)

    no_padding = torch.BoolTensor(feats.shape).fill_(False).to(device)

    if version == "v1":
        layer = 9
    else:
        layer = 12

    return {
        "source": source,
        "padding_mask": no_padding,
        "output_layer": layer,
    }
83
+
84
def hubert_extract_features(model, inputs):
    """Run the model's `extract_features` and return the feature tensor.

    Args:
        model: object exposing `extract_features(**inputs)` and `final_proj`.
        inputs: dict with at least "source" and "output_layer"; it is
            expanded as keyword arguments and is not modified.

    Returns:
        `model.final_proj` applied to the first element of the model output
        when `inputs["output_layer"] == 9`, otherwise that first element
        as-is.
    """
    # The source tensor may arrive in half precision while the model weights
    # are float32 (or the reverse), which makes the forward pass fail with a
    # dtype mismatch. Align the input with the model's own parameter dtype.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    source = inputs["source"]
    if (
        first_param is not None
        and source.is_floating_point()
        and source.dtype != first_param.dtype
    ):
        # Copy the dict so the caller's mapping stays untouched.
        inputs = {**inputs, "source": source.to(first_param.dtype)}

    with torch.no_grad():
        logits = model.extract_features(**inputs)
        feats = model.final_proj(logits[0]) if inputs["output_layer"] == 9 else logits[0]
    return feats
89
+
90
def general_preprocess(feats, *args):
    """Identity preprocessing step: return `feats` untouched.

    Extra positional arguments are accepted and ignored so this function can
    be called the same way as `hubert_preprocess`.
    """
    del args  # deliberately unused
    return feats
92
+
93
def general_prepare_input(feats, device, *args):
    """Move `feats` to `device` and return it as the model input.

    Args:
        feats: input tensor.
        device: target device string.
        *args: ignored. `main` calls every `prepare_input` hook with
            `(feats, device, version)`; without this catch-all the extra
            `version` argument raised `TypeError` for each file.

    Returns:
        `feats` on `device`.
    """
    del args  # deliberately unused
    return feats.to(device)
95
+
96
def general_extract_features(model, inputs):
    """Call `model(inputs)` without gradients and return its hidden states.

    Returns:
        The "last_hidden_state" entry of the model output.
    """
    with torch.no_grad():
        outputs = model(inputs)
    return outputs["last_hidden_state"]
100
+
101
+ # Model configurations
102
def _load_wav2vec_model(path, dev, half):
    """Load a wav2vec 2.0 checkpoint through `load_huggingface_model`."""
    # Imported here so the file's top-level import list stays unchanged.
    from transformers import Wav2Vec2Model

    return load_huggingface_model(path, dev, half, Wav2Vec2Model)


# Per-model pipeline hooks, keyed by model name. Every entry provides:
#   target_sr        - sample rate the audio is resampled to before inference
#   load_model       - (path, device, is_half) -> dict containing "model"
#   preprocess       - (feats, saved_cfg) -> feats
#   prepare_input    - (feats, device, version) -> model input
#   extract_features - (model, inputs) -> feature tensor
#
# The "contentvec" and "wav2vec" loaders previously referenced
# `ContentVecModel` and `Wav2VecModel`, names that are neither defined nor
# imported in this file, so selecting either model raised NameError.
# NOTE(review): contentvec now uses `HubertModelWithFinalProj` (the default
# class of `load_huggingface_model`) and wav2vec uses transformers'
# `Wav2Vec2Model`; confirm these match the intended checkpoints.
model_configs = {
    "hubert": {
        "target_sr": 16000,
        "load_model": load_hubert_fairseq,
        "preprocess": hubert_preprocess,
        "prepare_input": hubert_prepare_input,
        "extract_features": hubert_extract_features,
    },
    "contentvec": {
        "target_sr": 16000,
        "load_model": lambda path, dev, half: load_huggingface_model(
            path, dev, half, HubertModelWithFinalProj
        ),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
    "wav2vec": {
        "target_sr": 16000,
        "load_model": _load_wav2vec_model,
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
}
125
+
126
+ # Utility functions
127
def load_audio(file, target_sr):
    """Read an audio file as a mono signal at `target_sr`.

    Args:
        file: path of the audio file; surrounding whitespace is stripped.
        target_sr: sample rate of the returned signal.

    Returns:
        The samples, down-mixed with `librosa.to_mono` when the file has more
        than one dimension and resampled when its rate differs from
        `target_sr`.
    """
    samples, source_sr = sf.read(file.strip())
    is_multichannel = samples.ndim > 1
    if is_multichannel:
        samples = librosa.to_mono(samples.T)
    if source_sr == target_sr:
        return samples
    return librosa.resample(samples, orig_sr=source_sr, target_sr=target_sr)
134
+
135
def printt(f, strr):
    """Print `strr` to stdout and append it as one line to `f`.

    Args:
        f: writable text stream; it is flushed after the write.
        strr: message to emit.
    """
    print(strr)
    line = "{}\n".format(strr)
    f.write(line)
    f.flush()
139
+
140
+ # Main script
141
def main():
    """Extract features for one shard of ``<EXP_DIR>/1_16k_wavs``.

    Command line (positional, all required)::

        DEVICE N_PART I_PART GPU_IDS EXP_DIR VERSION MODEL_PATH MODEL_NAME

    * DEVICE: requested device string, passed to `Config`.
    * N_PART / I_PART: the sorted file list is sharded as
      ``files[I_PART::N_PART]``.
    * GPU_IDS: written to the ``CUDA_VISIBLE_DEVICES`` environment variable.
    * EXP_DIR: experiment directory holding ``1_16k_wavs``; features are
      written to ``3_feature<dim>`` and the log to
      ``extract_f0_feature.log`` inside it.
    * VERSION: "v1" or anything else; chooses output layer and feature size.
    * MODEL_PATH / MODEL_NAME: checkpoint location and key into
      `model_configs` (unknown names fall back to "hubert").

    Each ``.wav`` file yields a ``.npy`` file of the same stem. Files whose
    output already exists are skipped, outputs containing NaN are not saved,
    and a failure on one file is logged without stopping the run.

    Exits with status 2 on a malformed command line and status 1 when the
    model path does not exist.
    """
    # argv[7] and argv[8] have always been read unconditionally, so nine
    # entries is the only form that ever worked; report anything shorter
    # instead of failing with a bare IndexError.
    if len(sys.argv) < 9:
        print(
            f"Usage: {sys.argv[0]} DEVICE N_PART I_PART GPU_IDS EXP_DIR "
            "VERSION MODEL_PATH MODEL_NAME",
            file=sys.stderr,
        )
        sys.exit(2)

    device = sys.argv[1]
    n_part = int(sys.argv[2])
    i_part = int(sys.argv[3])
    exp_dir = sys.argv[5]
    version = sys.argv[6]
    model_path = sys.argv[7]
    model_name = sys.argv[8]

    # Must be set before Config() performs the first CUDA query.
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[4]

    config = Config(device)

    # The context manager closes the log on every exit path, including the
    # early return and sys.exit() below.
    with open(f"{exp_dir}/extract_f0_feature.log", "a+") as log_file:
        printt(log_file, f"Args: {sys.argv}")

        # Resolve model path and name. Keep the name that was requested on
        # the command line: the mapping below replaces `model_name` with the
        # model family, but the feature size depends on the requested name.
        requested_name = model_name
        custom_mappings = {
            "hubert_base": ("hubert_base.pt", "hubert"),
            "contentvec_base": ("contentvec_base.pt", "contentvec"),
            "hubert_large_ll60k": ("hubert_large_ll60k.pt", "hubert"),
        }
        if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
            model_path, model_name = custom_mappings[model_name]

        if not os.path.exists(model_path):
            printt(log_file, f"Error: {model_path} does not exist.")
            sys.exit(1)

        # Load model
        model_config = model_configs.get(model_name, model_configs["hubert"])
        model_dict = model_config["load_model"](model_path, config.device, config.is_half)
        model = model_dict["model"]
        additional_configs = model_dict.get("saved_cfg")
        printt(log_file, f"Loaded model from {model_path} on {config.device}")

        # Setup directories
        if version == "v1":
            feature_dim = 256
        elif requested_name == "hubert_large_ll60k":
            feature_dim = 1024
        else:
            feature_dim = 768
        wav_path = f"{exp_dir}/1_16k_wavs"
        out_path = f"{exp_dir}/3_feature{feature_dim}"
        os.makedirs(out_path, exist_ok=True)

        # Process audio files
        todo = sorted(os.listdir(wav_path))[i_part::n_part]
        printt(log_file, f"Total files to process: {len(todo)}")
        if not todo:
            printt(log_file, "No files to process.")
            return

        target_sr = model_config["target_sr"]
        suffix = ".wav"
        for file in todo:
            if not file.endswith(suffix):
                continue
            try:
                wav_file = f"{wav_path}/{file}"
                # Swap only the trailing extension; str.replace would also
                # rewrite a ".wav" occurring earlier in the name.
                out_file = f"{out_path}/{file[:-len(suffix)]}.npy"
                if os.path.exists(out_file):
                    continue

                # Load and preprocess audio
                wav = load_audio(wav_file, target_sr)
                feats = torch.from_numpy(wav).float().view(1, -1)
                preprocessed_feats = model_config["preprocess"](feats, additional_configs)
                inputs = model_config["prepare_input"](preprocessed_feats, config.device, version)
                feats = model_config["extract_features"](model, inputs)

                # Save features
                feats = feats.squeeze(0).float().cpu().numpy()
                if not np.isnan(feats).any():
                    np.save(out_file, feats, allow_pickle=False)
                    printt(log_file, f"Processed {file}: {feats.shape}")
                else:
                    printt(log_file, f"{file} contains NaN values")
            except Exception:
                # Best effort: record the failure and continue with the
                # remaining files.
                printt(log_file, traceback.format_exc())

        printt(log_file, "Feature extraction completed.")
223
+
224
# Run the extraction only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()