File size: 8,379 Bytes
948d8c3
 
 
 
 
 
 
 
 
 
585a5b0
9ba0d10
357ef6e
948d8c3
9ba0d10
948d8c3
 
 
 
 
 
 
 
 
 
e744b4f
948d8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585a5b0
9ba0d10
 
585a5b0
 
948d8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b0482d
948d8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503beb6
948d8c3
 
503beb6
948d8c3
 
 
 
 
 
 
 
 
 
0cc5fb5
948d8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os
import sys
import json
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import soundfile as sf
import librosa
from fairseq.tasks.audio_pretraining import AudioPretrainingTask
from fairseq import checkpoint_utils
from transformers import HubertModel, Wav2Vec2Model


# Environment settings for MPS fallback, applied in a single update call.
# NOTE(review): these are set after `import torch` above -- confirm the
# backend still honours them at that point.
os.environ.update(
    {
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "PYTORCH_MPS_HIGH_WATERMARK_RATIO": "0.0",
    }
)

# Configuration class for device and precision management
class Config:
    """Resolve the compute device / precision and load the JSON version configs.

    Attributes:
        device: Resolved device string ("cuda...", "mps" or "cpu").
        is_half: Whether models should be run in float16.
        version_config_paths: File names of the JSON configs that are loaded.
        json_config: Mapping of config file name to its parsed JSON content.
    """

    def __init__(self, device):
        """Build the configuration for the requested ``device`` string.

        Raises:
            OSError: If one of the JSON files under ``configs/`` cannot be opened.
        """
        # The requested device is only honoured when CUDA is usable;
        # device_config() may still switch the "cpu" fallback to "mps".
        self.device = device if torch.cuda.is_available() else "cpu"
        self.is_half = self.device != "cpu"
        self.version_config_paths = [
            os.path.join("", f"{k}.json") for k in ["32k", "40k", "48k", "48k_v2", "32k_v2"]
        ]
        self.json_config = self.load_config_json()
        self.device_config()

    def load_config_json(self):
        """Return ``{file_name: parsed_json}`` for every file in ``version_config_paths``.

        Files are read from the ``configs`` directory relative to the current
        working directory.
        """
        configs = {}
        for config_file in self.version_config_paths:
            config_path = os.path.join("configs", config_file)
            with open(config_path, "r") as f:
                configs[config_file] = json.load(f)
        return configs

    @staticmethod
    def _cuda_index(device):
        """Return the GPU ordinal of a ``"cuda"`` / ``"cuda:N"`` device string.

        A bare ``"cuda"`` carries no ordinal; it is mapped to index 0 instead of
        failing on ``int("cuda")``.
        """
        parts = device.split(":")
        return int(parts[-1]) if len(parts) > 1 else 0

    def device_config(self):
        """Finalise ``device`` and ``is_half`` from the available hardware."""
        if self.device.startswith("cuda"):
            i_device = self._cuda_index(self.device)
            gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
            # Half precision is only enabled for V100 cards with more than 4 GiB.
            self.is_half = gpu_mem > 4 and "V100" in torch.cuda.get_device_name(i_device)
            return

        # torch builds without the MPS backend do not expose torch.backends.mps.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        self.is_half = False

# Model-specific definitions
class HubertModelWithFinalProj(HubertModel):
    """HubertModel subclass that additionally owns a ``final_proj`` linear layer.

    Only the layer is created here; ``forward`` is not overridden in this class.
    """

    def __init__(self, config):
        super().__init__(config)
        in_features = config.hidden_size
        out_features = config.classifier_proj_size
        self.final_proj = nn.Linear(in_features, out_features)

def load_hubert_fairseq(model_path, device, is_half):
    """Load a fairseq checkpoint with the audio_pretraining task forced.

    Args:
        model_path: Path of the checkpoint file.
        device: Device string the model is moved to.
        is_half: Request float16 weights (ignored on "mps" and "cpu").

    Returns:
        dict with the eval-mode model under "model" and the config returned by
        the ensemble loader under "saved_cfg".
    """
    # Read the checkpoint once just to recover the task config it was saved with.
    checkpoint = checkpoint_utils.load_checkpoint_to_cpu(model_path)
    task_cfg = checkpoint["cfg"].task
    task = AudioPretrainingTask.setup_task(task_cfg)

    ensemble, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
        [model_path], task=task
    )
    model = ensemble[0].to(device)

    use_half = is_half and device not in ["mps", "cpu"]
    if use_half:
        model = model.half()

    model.eval()
    return {"model": model, "saved_cfg": saved_cfg}

def load_huggingface_model(model_path, device, is_half, model_class=HubertModelWithFinalProj):
    """Load ``model_class.from_pretrained(model_path)`` in eval mode.

    Args:
        model_path: Identifier/path handed to ``from_pretrained``.
        device: Device string the model is moved to.
        is_half: Request float16; only honoured when "cuda" is in ``device``.
        model_class: Class providing ``from_pretrained``.

    Returns:
        dict with the loaded model under "model".
    """
    if is_half and "cuda" in device:
        dtype = torch.float16
    else:
        dtype = torch.float32

    model = model_class.from_pretrained(model_path)
    model = model.to(device)
    model = model.to(dtype)
    model.eval()
    return {"model": model}

def hubert_preprocess(feats, saved_cfg):
    """Layer-normalise ``feats`` when ``saved_cfg.task.normalize`` is truthy.

    The normalisation spans the whole tensor (``normalized_shape`` is
    ``feats.shape``). Without the flag the input object is returned untouched.
    """
    if not saved_cfg.task.normalize:
        return feats
    with torch.no_grad():
        return F.layer_norm(feats, feats.shape)

def hubert_prepare_input(feats, device, version):
    """Assemble the keyword arguments for the fairseq ``extract_features`` call.

    Args:
        feats: Input tensor.
        device: Target device string.
        version: "v1" selects output layer 9, anything else layer 12.

    Returns:
        dict with "source", "padding_mask" (all-False bool tensor shaped like
        ``feats``) and "output_layer".
    """
    if device in ("mps", "cpu"):
        source = feats.to(device)
    else:
        # Every other device receives a float16 copy of the input.
        source = feats.half().to(device)

    mask = torch.zeros(feats.shape, dtype=torch.bool).to(device)

    if version == "v1":
        layer = 9
    else:
        layer = 12

    return {
        "source": source,
        "padding_mask": mask,
        "output_layer": layer,
    }

def hubert_extract_features(model, inputs):
    """Run ``model.extract_features`` and return the content features.

    Args:
        model: Object exposing ``extract_features(**inputs)`` and ``final_proj``.
        inputs: dict with "source", "padding_mask" and "output_layer".

    Returns:
        ``model.final_proj`` applied to the first output when "output_layer" is
        9, otherwise the first output itself. Computed without gradients.
    """
    source = inputs["source"]
    # The input is cast to float16 on every non-CPU/MPS device, while the model
    # is only halved when half precision was requested. Align the input with
    # the model's parameters so the two can never disagree.
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    if (
        first_param is not None
        and first_param.is_floating_point()
        and torch.is_floating_point(source)
        and source.dtype != first_param.dtype
    ):
        # Copy the dict so the caller's inputs are left unmodified.
        inputs = {**inputs, "source": source.to(first_param.dtype)}

    with torch.no_grad():
        outputs = model.extract_features(**inputs)
        hidden = outputs[0]
        if inputs["output_layer"] == 9:
            return model.final_proj(hidden)
        return hidden

def general_preprocess(feats, *_unused):
    """Identity preprocessing step: hand ``feats`` back unchanged.

    Extra positional arguments are accepted and ignored.
    """
    return feats

def general_prepare_input(feats, device, *args):
    """Move ``feats`` to ``device``.

    Args:
        feats: Input tensor.
        device: Target device string.
        *args: Ignored. The shared pipeline calls every ``prepare_input`` hook
            with an extra ``version`` argument, which this hook does not need;
            accepting it avoids a TypeError on that call.

    Returns:
        ``feats`` on ``device``.
    """
    return feats.to(device)

def general_extract_features(model, inputs):
    """Return ``model(inputs)["last_hidden_state"]`` computed without gradients.

    Args:
        model: Callable whose result is indexable by "last_hidden_state".
        inputs: Input tensor.

    Returns:
        The "last_hidden_state" entry of the model output.
    """
    # The model may have been cast to float16 at load time while the input is
    # only moved to the device, so match the input dtype to the parameters.
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    if (
        first_param is not None
        and first_param.is_floating_point()
        and torch.is_floating_point(inputs)
        and inputs.dtype != first_param.dtype
    ):
        inputs = inputs.to(first_param.dtype)

    with torch.no_grad():
        return model(inputs)["last_hidden_state"]

# Model configurations
# Per-model pipeline hooks, keyed by model name:
#   target_sr        sample rate audio is resampled to before extraction
#   load_model       (path, device, is_half) -> dict containing "model"
#   preprocess       (feats, saved_cfg) -> feats
#   prepare_input    builds the model input from the preprocessed feats
#   extract_features (model, inputs) -> feature tensor
model_configs = {
    "hubert": {
        "target_sr": 16000,
        "load_model": load_hubert_fairseq,
        "preprocess": hubert_preprocess,
        "prepare_input": hubert_prepare_input,
        "extract_features": hubert_extract_features,
    },
    "contentvec": {
        "target_sr": 16000,
        # No ContentVecModel class is defined or imported in this file, so the
        # previous reference raised NameError as soon as the loader was called.
        # NOTE(review): HubertModelWithFinalProj (the loader's default class) is
        # assumed to be the intended ContentVec architecture -- confirm.
        "load_model": lambda path, dev, half: load_huggingface_model(
            path, dev, half, HubertModelWithFinalProj
        ),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
    "wav2vec": {
        "target_sr": 16000,
        "load_model": lambda path, dev, half: load_huggingface_model(path, dev, half, Wav2Vec2Model),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
}

# Utility functions
def load_audio(file, target_sr):
    """Read an audio file as a mono signal at ``target_sr``.

    Args:
        file: Path of the audio file; surrounding whitespace is stripped.
        target_sr: Desired sample rate.

    Returns:
        The samples, down-mixed when multi-dimensional and resampled when the
        file's rate differs from ``target_sr``.
    """
    path = file.strip()
    samples, source_sr = sf.read(path)

    if samples.ndim > 1:
        # Transposed before mixing down, as librosa.to_mono is given the
        # channel axis first.
        samples = librosa.to_mono(samples.T)

    if source_sr == target_sr:
        return samples
    return librosa.resample(samples, orig_sr=source_sr, target_sr=target_sr)

def printt(f, strr):
    """Echo ``strr`` to stdout and append it as one line to the open file ``f``.

    The file is flushed after every message.
    """
    line = "{}\n".format(strr)
    print(strr)
    f.write(line)
    f.flush()

# Main script
def main():
    """Extract features for one shard of ``<exp_dir>/1_16k_wavs``.

    Positional command-line arguments:
        device n_part i_part i_gpu exp_dir version model_path model_name

    Every ``.wav`` file of the shard (``sorted(files)[i_part::n_part]``) that has
    no output yet is run through the selected model and stored as ``.npy`` under
    ``<exp_dir>/3_feature<dim>``. Progress and per-file errors are written to
    stdout and appended to ``<exp_dir>/extract_f0_feature.log``.

    Exits with a non-zero status when too few arguments are given, when the
    model path does not exist, or when the model name is unknown.
    """
    # model_path / model_name are always read from argv[7] / argv[8], so the
    # shorter argument forms could only ever end in an IndexError. Fail early
    # with a usage message instead.
    if len(sys.argv) < 9:
        sys.exit(
            f"Usage: {sys.argv[0]} device n_part i_part i_gpu exp_dir version "
            "model_path model_name"
        )

    device = sys.argv[1]
    n_part = int(sys.argv[2])
    i_part = int(sys.argv[3])
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[4]
    exp_dir = sys.argv[5]
    version = sys.argv[6]
    model_path = sys.argv[7]
    model_name = sys.argv[8]

    config = Config(device)

    # The context manager closes the log on every exit path, including the
    # early return and the sys.exit() calls below.
    with open(f"{exp_dir}/extract_f0_feature.log", "a+") as log_file:
        printt(log_file, f"Args: {sys.argv}")

        # Resolve model path and name
        # NOTE(review): "hubert_base_japanese" resolves to a model name that has
        # no entry in model_configs -- confirm which pipeline it should use.
        custom_mappings = {
            "wav2vec" : ("wav2vec_small_960h.pt", "wav2vec"),
            "hubert_base": ("hubert_base.pt", "hubert"),
            "contentvec_base": ("contentvec_base.pt", "contentvec"),
            "hubert_base_japanese" : ("hubert_base_japanese.pt","hubert_base_japanese")
        }
        if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
            model_path, model_name = custom_mappings[model_name]

        if not os.path.exists(model_path):
            printt(log_file, f"Error: {model_path} does not exist.")
            sys.exit(1)

        # Load model
        model_config = model_configs.get(model_name)
        if model_config is None:
            printt(
                log_file,
                f"Error: unknown model '{model_name}'. Available: {sorted(model_configs)}",
            )
            sys.exit(1)
        model_dict = model_config["load_model"](model_path, config.device, config.is_half)
        model = model_dict["model"]
        additional_configs = model_dict.get("saved_cfg")
        printt(log_file, f"Loaded model from {model_path} on {config.device}")

        # Setup directories
        feature_dim = 256 if version == "v1" else 768 if model_name != "hubert_large_ll60k" else 1024
        wav_path = f"{exp_dir}/1_16k_wavs"
        out_path = f"{exp_dir}/3_feature{feature_dim}"
        os.makedirs(out_path, exist_ok=True)

        # Process audio files
        todo = sorted(os.listdir(wav_path))[i_part::n_part]
        printt(log_file, f"Total files to process: {len(todo)}")
        if not todo:
            printt(log_file, "No files to process.")
            return

        target_sr = model_config["target_sr"]
        for file in todo:
            if not file.endswith(".wav"):
                continue
            try:
                wav_file = f"{wav_path}/{file}"
                # Swap only the trailing extension, not every ".wav" in the name.
                out_file = f"{out_path}/{file[:-len('.wav')]}.npy"
                if os.path.exists(out_file):
                    continue

                # Load and preprocess audio
                wav = load_audio(wav_file, target_sr)
                feats = torch.from_numpy(wav).float().view(1, -1)
                preprocessed_feats = model_config["preprocess"](feats, additional_configs)
                inputs = model_config["prepare_input"](preprocessed_feats, config.device, version)
                feats = model_config["extract_features"](model, inputs)

                # Save features
                feats = feats.squeeze(0).float().cpu().numpy()
                if not np.isnan(feats).any():
                    np.save(out_file, feats, allow_pickle=False)
                    printt(log_file, f"Processed {file}: {feats.shape}")
                else:
                    printt(log_file, f"{file} contains NaN values")
            except Exception:
                # Best effort: log the failure and carry on with the next file.
                printt(log_file, traceback.format_exc())

        printt(log_file, "Feature extraction completed.")

# Run the extraction only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()