| |
| |
| |
| |
|
|
| import os |
| import torch |
| import numpy as np |
| import yaml |
| import copy |
| from tqdm import tqdm |
| from torchaudio.compliance import kaldi |
| from torch.nn.utils.rnn import pad_sequence |
| from torch.utils.data import DataLoader |
| from fairseq import checkpoint_utils |
| from transformers import AutoModel, Wav2Vec2FeatureExtractor |
|
|
| from utils.io_optim import ( |
| TorchaudioDataset, |
| LibrosaDataset, |
| FFmpegDataset, |
| collate_batch, |
| ) |
| import whisper |
| from modules.wenet_extractor.utils.init_model import init_model |
| from modules.wenet_extractor.utils.checkpoint import load_checkpoint |
|
|
| """ |
| Extractor for content features |
| 1. whisper |
| 2. contentvec |
| 3. wenet |
| 4. mert |
| |
| Pipeline: |
| in preprocess.py: |
| call extract_utt_content_features() to extract content features for each utterance |
| extract_utt_content_features() envelopes the following steps: |
| 1. load the model (whisper, contentvec, wenet) |
| 2. extract the content features |
| 3. save the content features into files |
| in svc_dataset.py: |
| call offline_align() to align the content features to the given target length |
| |
| """ |
|
|
| """ |
| Extractor Usage: |
| 1. initialize an instance of extractor |
| extractor = WhisperExtractor(cfg) |
| 2. load the specified model |
| extractor.load_model() |
| 3. extract the content features |
| extractor.extract_content(utt) for single utterance |
| extractor.extract_content_batch(utts) for batch utterances |
| 4. save the content features |
| extractor.save_feature(utt, content_feature) for single utterance |
| """ |
|
|
|
|
class AudioPretrainedModelFeaturesExtractor:
    """Base class for pretrained-model content feature extractors.

    Provides the shared machinery for whisper/contentvec/wenet/mert
    extractors: resolution transformation ("ReTrans") between the source
    model's frame rate and the target acoustic frame rate, trimming of
    padded features to the utterance's true length, and saving to disk.
    """

    def __init__(self, cfg, extractor_type):
        """
        Args:
            cfg: config object; reads cfg.preprocess.* frameshift/hop fields
            extractor_type: one of "whisper", "contentvec", "wenet", "mert"
        """
        self.cfg = cfg
        self.extractor_type = extractor_type
        self.model = None
        self.init_for_retrans()

    def init_for_retrans(self):
        """Precompute reduced source/target hop sizes for frame alignment.

        Both hops are divided by their GCD so the repeat/average resampling
        in (offline_)resolution_transformation uses the smallest factors.
        """
        target_hop = self.cfg.preprocess.hop_size

        assert self.extractor_type in ["whisper", "contentvec", "wenet", "mert"]
        if self.extractor_type == "whisper":
            source_hop = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "contentvec":
            source_hop = (
                self.cfg.preprocess.contentvec_frameshift
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "wenet":
            source_hop = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "mert":
            # Fix: "mert" was previously rejected by the assert above even
            # though MertExtractor inherits this initializer (its frameshift
            # is handled in get_valid_features the same way).
            source_hop = (
                self.cfg.preprocess.mert_frameshift
                * self.cfg.preprocess.sample_rate
            )
        source_hop = int(source_hop)

        # Reduce by the GCD so that repeat(source_hop) / average(target_hop)
        # work with the smallest equivalent factors.
        factor = np.gcd(source_hop, target_hop)
        source_hop //= factor
        target_hop //= factor

        self.source_hop = source_hop
        self.target_hop = target_hop

    def offline_resolution_transformation(self, content, target_len):
        """Align one utterance's content features to ``target_len`` frames.

        Upsamples by frame repetition (source_hop) and downsamples by frame
        averaging (target_hop), then pads/truncates to the exact length.

        args:
            content: numpy array, (source_len, dim)
            target_len: int, target number of frames
        return:
            mapped_feature: numpy array, (target_len, dim)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        _, width = content.shape
        # Only the prefix of content that can contribute to target_len frames.
        source_len = min(target_len * target_hop // source_hop + 1, len(content))

        # Largest multiple of target_hop covered by the upsampled sequence.
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const, dim) -> (const // target_hop, target_hop, dim) -> mean over axis=1
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            # Reuse the shared error logger (this code was duplicated inline).
            self.log_for_ReTrans(err)

        if len(down_sampling_feats) < target_len:
            # Pad by repeating the last frame.
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        mapped_feature = down_sampling_feats[:target_len]

        return mapped_feature

    def log_for_ReTrans(self, err):
        """Persist the maximum observed alignment error to align_max_err.log."""
        err_log_dir = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_dir, "r") as f:
                err_num = int(f.read())
        except (FileNotFoundError, ValueError):
            # No log yet, or unreadable content: initialize it with 0.
            with open(err_log_dir, "w") as f:
                f.write("0")
            err_num = 0
        if err > err_num:
            with open(err_log_dir, "w") as f:
                f.write(str(err))

    def ReTrans(self, source_feats, padded_target_len):
        """
        Resolution Transformation for mismatched frames alignment (batched).

        TODO: Merge the offline resolution_transformation into one

        args:
            source_feats: Tensor, (B, padded_source_len, D)
            padded_target_len: int, the maximum target length in a batch
        return:
            mapped_feature: Tensor, (B, padded_target_len, D)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        B, padded_source_len, D = source_feats.shape

        # Only the prefix that can contribute to padded_target_len frames.
        source_len = min(
            padded_target_len * target_hop // source_hop + 1, padded_source_len
        )

        # Largest multiple of target_hop covered by the upsampled sequence.
        const = source_len * source_hop // target_hop * target_hop

        # (B, source_len * source_hop, D), truncated to const frames
        up_sampling_feats = torch.repeat_interleave(source_feats, source_hop, dim=1)[
            :, :const
        ]
        # Average each group of target_hop frames.
        down_sampling_feats = torch.mean(
            up_sampling_feats.reshape(B, -1, target_hop, D), dim=2
        )

        err = abs(padded_target_len - down_sampling_feats.shape[1])
        if err > 8:
            self.log_for_ReTrans(err)

        if down_sampling_feats.shape[1] < padded_target_len:
            # Pad by repeating the last frame.
            end = down_sampling_feats[:, -1, :][:, None, :].repeat_interleave(
                err, dim=1
            )
            down_sampling_feats = torch.cat([down_sampling_feats, end], dim=1)

        # (B, padded_target_len, D)
        mapped_feature = down_sampling_feats[:, :padded_target_len]
        return mapped_feature

    def get_valid_features(self, utt, content_feature):
        """Trim padded content features to the utterance's true frame count.

        Args:
            utt: dict, one metadata item; must contain "Duration" (seconds)
            content_feature: 2-D tensor/array, (num_frames, dim)
        Returns:
            content_feature truncated to the valid number of frames
        Raises:
            NotImplementedError: for an unknown extractor_type
        """
        duration = utt["Duration"]
        if self.extractor_type == "whisper":
            frameshift = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
            )
        elif self.extractor_type == "contentvec":
            frameshift = self.cfg.preprocess.contentvec_frameshift
        elif self.extractor_type == "wenet":
            frameshift = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
            )
        elif self.extractor_type == "mert":
            frameshift = self.cfg.preprocess.mert_frameshift
        else:
            raise NotImplementedError

        # Number of frames needed to cover `duration` seconds at `frameshift`.
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        content_feature = content_feature[:num_frames, :]
        return content_feature

    def save_feature(self, utt, content_feature):
        """Save a single utterance to path {cfg.preprocess.processed_dir}

        Args:
            utt (dict): one item in metadata, containing information for one utterance
            content_feature (tensor): content feature of one utterance
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        # Drop padding frames before writing to disk.
        content_feature = self.get_valid_features(utt, content_feature)
        np.save(save_path, content_feature.cpu().detach().numpy())
|
|
|
|
class WhisperExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content feature extractor backed by the Whisper encoder."""

    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config, extractor_type="whisper")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        """Load the configured Whisper checkpoint into self.model (eval mode)."""
        print("Loading Whisper Model...")

        download_root = None
        if "whisper_model_path" in self.cfg.preprocess:
            configured_path = self.cfg.preprocess.whisper_model_path
            if os.path.isfile(configured_path):
                # A checkpoint file: use its parent directory as download root.
                download_root = os.path.dirname(configured_path)
            elif os.path.isdir(configured_path):
                # Already a directory.
                download_root = configured_path
            else:
                # Path does not exist yet; strip a trailing ".pt" filename
                # so the directory can be used as the download destination.
                download_root = configured_path
                if download_root.endswith(".pt"):
                    download_root = os.path.dirname(download_root)

        model = whisper.load_model(
            self.cfg.preprocess.whisper_model, self.device, download_root
        )
        if torch.cuda.is_available():
            print("Using GPU...\n")
            model = model.cuda()
        else:
            print("Using CPU...\n")

        self.model = model.eval()

    def extract_content_features(self, wavs):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch_size, T)
        """
        # Whisper expects fixed-length (30 s) inputs.
        padded_wavs = whisper.pad_or_trim(wavs)
        # (batch, n_mels, n_frames) on the model's device.
        batch_mel = whisper.log_mel_spectrogram(padded_wavs, device=self.model.device)
        with torch.no_grad():
            return self.model.embed_audio(batch_mel)
|
|
|
|
class ContentvecExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content feature extractor backed by a ContentVec (fairseq) checkpoint."""

    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg, extractor_type="contentvec")

    def load_model(self):
        """Load the ContentVec checkpoint and move it to GPU if available."""
        # Idiom fix: identity comparison with None (was "== None").
        assert self.model is None
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        # Only the first model of the ensemble is used; cfg/task are unused.
        models, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T), zero-padded to the max length in the batch
        Returns:
            feats: tensor (batch, num_frames, dim)
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)
        # Zero samples mark padding (waveforms are zero-padded by the collate).
        padding_mask = torch.eq(wavs, torch.zeros_like(wavs)).to(device)
        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            # Project the layer-12 representation to the final feature space.
            feats = self.model.final_proj(logits[0])
        return feats
|
|
|
|
class WenetExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content feature extractor backed by a WeNet ASR encoder."""

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config, extractor_type="wenet")

    def load_model(self):
        """Load the WeNet config + checkpoint into self.model (eval mode)."""
        config_path = self.cfg.preprocess.wenet_config
        checkpoint_path = self.cfg.preprocess.wenet_model_path

        with open(config_path, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        # Keep a private copy of the dataset conf for feature extraction.
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, checkpoint_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor, whose shape is (B, T)
            lens: list, the valid (unpadded) length of each waveform
        """
        device = next(self.model.parameters()).device

        assert self.extract_conf is not None, "load model first!"
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]

        feats_list = []
        lengths_list = []
        for idx, wav in enumerate(wavs):
            # Drop the padding tail, then append a short zero tail and scale
            # to 16-bit integer range as kaldi front-ends expect.
            valid = wav[: lens[idx]].to(device)
            padded = torch.cat((valid, torch.zeros(160, device=valid.device)), dim=-1)
            frame_input = (padded * (1 << 15)).unsqueeze(0)

            if feats_type == "fbank":
                conf = self.extract_conf.get("fbank_conf", {})
                feat = kaldi.fbank(
                    frame_input,
                    sample_frequency=16000,
                    num_mel_bins=conf["num_mel_bins"],
                    frame_length=conf["frame_length"],
                    frame_shift=conf["frame_shift"],
                    dither=conf["dither"],
                )
            else:
                conf = self.extract_conf.get("mfcc", {})
                feat = kaldi.mfcc(
                    frame_input,
                    sample_frequency=16000,
                    num_mel_bins=conf["num_mel_bins"],
                    frame_length=conf["frame_length"],
                    frame_shift=conf["frame_shift"],
                    dither=conf["dither"],
                    num_ceps=conf.get("num_ceps", 40),
                    high_freq=conf.get("high_freq", 0.0),
                    low_freq=conf.get("low_freq", 20.0),
                )
            feats_list.append(feat)
            lengths_list.append(feat.shape[0])

        feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
        feats_tensor = pad_sequence(feats_list, batch_first=True).to(device)

        return self.model.encoder_extractor(
            feats_tensor,
            feats_lengths,
            decoding_chunk_size=-1,
            num_decoding_left_chunks=-1,
            simulate_streaming=False,
        )
|
|
|
|
class MertExtractor(AudioPretrainedModelFeaturesExtractor):
    """Content feature extractor backed by a MERT model (HuggingFace)."""

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg, extractor_type="mert")
        self.preprocessor = None

    def load_model(self):
        """Load the MERT model and its Wav2Vec2 feature preprocessor."""
        # Idiom fix: identity comparison with None (was "== None").
        assert self.model is None
        assert self.preprocessor is None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
        Returns:
            list of per-utterance feature tensors from the configured
            hidden layer (batch dim squeezed off)
        """
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )
            mert_features = []
            for wav in wavs:
                # Bug fix: the loop previously fed the whole batch ("wavs")
                # to the preprocessor on every iteration instead of the
                # single utterance "wav", recomputing the batch B times and
                # making squeeze(0) incorrect for batch sizes > 1.
                inputs = self.preprocessor(
                    wav, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # Take the configured hidden layer; squeeze the batch dim.
                feature = outputs.hidden_states[
                    self.cfg.preprocess.mert_feature_layer
                ].squeeze(0)
                mert_features.append(feature)

        return mert_features
|
|
|
|
def _extract_content_features_for_model(
    cfg,
    metadata,
    dataset_name,
    num_workers,
    feature_name,
    dataset_cls,
    sample_rate,
    extractor_cls,
    needs_lens=False,
):
    """Run one extractor over the dataset and save per-utterance features.

    Skips extraction entirely when the output directory already holds one
    saved file per utterance in ``metadata``.

    Args:
        cfg: config object
        metadata: list of utterance dicts
        dataset_name: name of the dataset (metadata[0]["Dataset"])
        num_workers: DataLoader worker count
        feature_name: subdirectory / feature type name (e.g. "whisper")
        dataset_cls: waveform dataset class to instantiate
        sample_rate: sample rate to load waveforms at
        extractor_cls: AudioPretrainedModelFeaturesExtractor subclass
        needs_lens: whether extract_content_features also takes the lengths
    """
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feature_name)
    os.makedirs(feat_dir, exist_ok=True)

    # Already extracted: one file per utterance.
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    waveforms = dataset_cls(cfg, dataset_name, sample_rate, metadata=metadata)
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )
    extractor = extractor_cls(cfg)
    extractor.load_model()
    for batch_metadata, wavs, lens in tqdm(data_loader):
        if needs_lens:
            batch_content_features = extractor.extract_content_features(wavs, lens)
        else:
            batch_content_features = extractor.extract_content_features(wavs)
        for utt, feature in zip(batch_metadata, batch_content_features):
            extractor.save_feature(utt, feature)


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract and save every enabled content feature for all utterances.

    The four feature types (whisper/contentvec/wenet/mert) share the same
    directory-check / dataloader / extract / save pipeline; only the dataset
    class, sample rate, and extractor differ, so the common logic lives in
    ``_extract_content_features_for_model``.

    Args:
        cfg: config object with cfg.preprocess.extract_*_feature flags
        metadata: list of utterance dicts (all from the same dataset)
        num_workers: DataLoader worker count
    """
    dataset_name = metadata[0]["Dataset"]
    with torch.no_grad():
        if cfg.preprocess.extract_whisper_feature:
            _extract_content_features_for_model(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                feature_name="whisper",
                dataset_cls=FFmpegDataset,
                sample_rate=cfg.preprocess.whisper_sample_rate,
                extractor_cls=WhisperExtractor,
            )

        if cfg.preprocess.extract_contentvec_feature:
            _extract_content_features_for_model(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                feature_name="contentvec",
                dataset_cls=LibrosaDataset,
                sample_rate=cfg.preprocess.contentvec_sample_rate,
                extractor_cls=ContentvecExtractor,
            )

        if cfg.preprocess.extract_wenet_feature:
            _extract_content_features_for_model(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                feature_name="wenet",
                dataset_cls=TorchaudioDataset,
                sample_rate=cfg.preprocess.wenet_sample_rate,
                extractor_cls=WenetExtractor,
                needs_lens=True,  # wenet needs the valid lengths per waveform
            )

        if cfg.preprocess.extract_mert_feature:
            _extract_content_features_for_model(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                feature_name="mert",
                dataset_cls=TorchaudioDataset,
                sample_rate=cfg.preprocess.mert_sample_rate,
                extractor_cls=MertExtractor,
            )
|