| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | import torch |
| | from tqdm import tqdm |
| | import numpy as np |
| |
|
| | from transformers import Wav2Vec2FeatureExtractor |
| | from transformers import AutoModel |
| | import torchaudio |
| | import torchaudio.transforms as T |
| | from sklearn.preprocessing import StandardScaler |
| |
|
| |
|
| | def mert_encoder(model, processor, audio_path, hps): |
| | """ |
| | # mert default sr: 24000 |
| | """ |
| | with torch.no_grad(): |
| | resample_rate = processor.sampling_rate |
| | device = next(model.parameters()).device |
| |
|
| | input_audio, sampling_rate = torchaudio.load(audio_path) |
| | input_audio = input_audio.squeeze() |
| |
|
| | if sampling_rate != resample_rate: |
| | resampler = T.Resample(sampling_rate, resample_rate) |
| | input_audio = resampler(input_audio) |
| |
|
| | inputs = processor( |
| | input_audio, sampling_rate=resample_rate, return_tensors="pt" |
| | ).to( |
| | device |
| | ) |
| |
|
| | outputs = model(**inputs, output_hidden_states=True) |
| |
|
| | |
| | |
| | |
| |
|
| | feature = outputs.hidden_states[ |
| | hps.mert_feature_layer |
| | ].squeeze() |
| |
|
| | return feature.cpu().detach().numpy() |
| |
|
| |
|
| | def mert_features_normalization(raw_mert_features): |
| | normalized_mert_features = list() |
| |
|
| | mert_features = np.array(raw_mert_features) |
| | scaler = StandardScaler().fit(mert_features) |
| | for raw_mert_feature in raw_mert_feature: |
| | normalized_mert_feature = scaler.transform(raw_mert_feature) |
| | normalized_mert_features.append(normalized_mert_feature) |
| | return normalized_mert_features |
| |
|
| |
|
| | def get_mapped_mert_features(raw_mert_features, mapping_features, fast_mapping=True): |
| | source_hop = 320 |
| | target_hop = 256 |
| |
|
| | factor = np.gcd(source_hop, target_hop) |
| | source_hop //= factor |
| | target_hop //= factor |
| | print( |
| | "Mapping source's {} frames => target's {} frames".format( |
| | target_hop, source_hop |
| | ) |
| | ) |
| |
|
| | mert_features = [] |
| | for index, mapping_feat in enumerate(tqdm(mapping_features)): |
| | |
| | target_len = mapping_feat.shape[0] |
| |
|
| | |
| | raw_feats = raw_mert_features[index].cpu().numpy() |
| | source_len, width = raw_feats.shape |
| |
|
| | |
| | const = source_len * source_hop // target_hop * target_hop |
| |
|
| | |
| | up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) |
| | |
| | down_sampling_feats = np.average( |
| | up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 |
| | ) |
| |
|
| | err = abs(target_len - len(down_sampling_feats)) |
| | if err > 3: |
| | print("index:", index) |
| | print("mels:", mapping_feat.shape) |
| | print("raw mert vector:", raw_feats.shape) |
| | print("up_sampling:", up_sampling_feats.shape) |
| | print("const:", const) |
| | print("down_sampling_feats:", down_sampling_feats.shape) |
| | exit() |
| | if len(down_sampling_feats) < target_len: |
| | |
| | end = down_sampling_feats[-1][None, :].repeat(err, axis=0) |
| | down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) |
| |
|
| | |
| | feats = down_sampling_feats[:target_len] |
| | mert_features.append(feats) |
| |
|
| | return mert_features |
| |
|
| |
|
| | def load_mert_model(hps): |
| | print("Loading MERT Model: ", hps.mert_model) |
| |
|
| | |
| | model_name = hps.mert_model |
| | model = AutoModel.from_pretrained(model_name, trust_remote_code=True) |
| |
|
| | if torch.cuda.is_available(): |
| | model = model.cuda() |
| |
|
| | |
| |
|
| | preprocessor = Wav2Vec2FeatureExtractor.from_pretrained( |
| | model_name, trust_remote_code=True |
| | ) |
| | return model, preprocessor |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|