| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import torch |
| | import numpy as np |
| | import yaml |
| | import copy |
| | from tqdm import tqdm |
| | from torchaudio.compliance import kaldi |
| | from torch.nn.utils.rnn import pad_sequence |
| | from torch.utils.data import DataLoader |
| | from fairseq import checkpoint_utils |
| | from transformers import AutoModel, Wav2Vec2FeatureExtractor |
| |
|
| | from utils.io_optim import ( |
| | TorchaudioDataset, |
| | LibrosaDataset, |
| | FFmpegDataset, |
| | collate_batch, |
| | ) |
| | from modules import whisper_extractor as whisper |
| | from modules.wenet_extractor.utils.init_model import init_model |
| | from modules.wenet_extractor.utils.checkpoint import load_checkpoint |
| |
|
| | """ |
| | Extractor for content features |
| | 1. whisper |
| | 2. contentvec |
| | 3. wenet |
| | 4. mert |
| | |
| | Pipeline: |
| | in preprocess.py: |
| | call extract_utt_content_features() to extract content features for each utterance |
| | extract_utt_content_features() wraps the following steps: |
| | 1. load the model (whisper, contentvec, wenet, mert) |
| | 2. extract the content features |
| | 3. save the content features into files |
| | in svc_dataset.py: |
| | call offline_align() to align the content features to the given target length |
| | |
| | """ |
| |
|
| | """ |
| | Extractor Usage: |
| | 1. initialize an instance of extractor |
| | extractor = WhisperExtractor(cfg) |
| | 2. load the specified model |
| | extractor.load_model() |
| | 3. extract the content features |
| | extractor.extract_content_features(wavs, lens) for a batch of waveforms |
| | 4. save the content features |
| | extractor.save_feature(utt, content_feature) for single utterance |
| | """ |
| |
|
| |
|
class BaseExtractor:
    """Shared base class of the content-feature extractors.

    It stores the config and implements the two model-independent steps:
    aligning a feature sequence to a target number of frames
    (``offline_align``) and saving one utterance's features to disk
    (``save_feature``). Subclasses are expected to set ``extractor_type``
    and to populate ``model``.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.extractor_type = None
        self.model = None

    def _source_hop_in_samples(self):
        """Return the extractor's hop as ``int(frameshift * ... * sample_rate)``."""
        preprocess_cfg = self.cfg.preprocess
        if self.extractor_type == "whisper":
            source_hop = (
                preprocess_cfg.whisper_frameshift
                * preprocess_cfg.whisper_downsample_rate
                * preprocess_cfg.sample_rate
            )
        elif self.extractor_type == "contentvec":
            source_hop = (
                preprocess_cfg.contentvec_frameshift * preprocess_cfg.sample_rate
            )
        else:  # "wenet" -- the caller has already validated the type
            source_hop = (
                preprocess_cfg.wenet_frameshift
                * preprocess_cfg.wenet_downsample_rate
                * preprocess_cfg.sample_rate
            )
        return int(source_hop)

    def _record_max_align_error(self, err):
        """Keep the largest alignment error seen so far in ``align_max_err.log``."""
        err_log_path = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_path, "r") as f:
                err_num = int(f.read())
        except (OSError, ValueError):
            # Missing, unreadable or corrupt log: start over from zero. Only
            # these errors are handled so that e.g. KeyboardInterrupt still
            # propagates.
            with open(err_log_path, "w") as f:
                f.write("0")
            err_num = 0
        if err > err_num:
            with open(err_log_path, "w") as f:
                f.write(str(err))

    def offline_align(self, content, target_len):
        """Resample content features along time to exactly ``target_len`` frames.

        Both hops are divided by their gcd; every source frame is then
        repeated ``source_hop`` times and the result is averaged in groups of
        ``target_hop``. A result that is too short is padded by repeating its
        last frame, one that is too long is truncated.

        args:
            content: (source_len, dim)
            target_len: target length
        return:
            mapped_feature: (target_len, dim)
        """
        target_hop = self.cfg.preprocess.hop_size

        assert self.extractor_type in ["whisper", "contentvec", "wenet"], (
            "offline_align does not support extractor type "
            "{}".format(self.extractor_type)
        )
        source_hop = self._source_hop_in_samples()

        # Reduce both hops by their gcd so the repeat factor stays small.
        factor = np.gcd(source_hop, target_hop)
        source_hop //= factor
        target_hop //= factor

        # (source_len, dim)
        _, width = content.shape
        # Number of source frames needed to cover target_len target frames,
        # capped by what is actually available.
        source_len = min(target_len * target_hop // source_hop + 1, len(content))

        # Largest multiple of target_hop that fits in the up-sampled sequence,
        # so the reshape below is exact.
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const // target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            # A large mismatch is recorded but not treated as fatal.
            self._record_max_align_error(err)

        if len(down_sampling_feats) < target_len:
            # Pad by repeating the last frame: (err, dim)
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        mapped_feature = down_sampling_feats[:target_len]

        return mapped_feature

    def save_feature(self, utt, content_feature):
        """Save a single utterance to path {cfg.preprocess.processed_dir}

        The file is written to
        ``{processed_dir}/{utt["Dataset"]}/{extractor_type}/{utt["Uid"]}.npy``
        after cutting the feature to the frame count derived from
        ``utt["Duration"]`` and the extractor's frame shift.

        Args:
            utt (dict): one item in metadata, containing information for one utterance
            content_feature (tensor): content feature of one utterance,
                shaped (num_frames, dim)

        Raises:
            NotImplementedError: if ``extractor_type`` is not one of
                whisper / contentvec / wenet / mert.
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        # NOTE(review): duration and frameshift are presumably both in
        # seconds -- confirm against the metadata producer.
        duration = utt["Duration"]
        preprocess_cfg = self.cfg.preprocess
        if self.extractor_type == "whisper":
            frameshift = (
                preprocess_cfg.whisper_frameshift
                * preprocess_cfg.whisper_downsample_rate
            )
        elif self.extractor_type == "contentvec":
            frameshift = preprocess_cfg.contentvec_frameshift
        elif self.extractor_type == "wenet":
            frameshift = (
                preprocess_cfg.wenet_frameshift * preprocess_cfg.wenet_downsample_rate
            )
        elif self.extractor_type == "mert":
            frameshift = preprocess_cfg.mert_frameshift
        else:
            raise NotImplementedError

        # Number of frames covering the utterance; drops frames that come
        # from batch padding.
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1

        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        content_feature = content_feature[:num_frames, :]
        np.save(save_path, content_feature.cpu().detach().numpy())
| |
|
| |
|
class WhisperExtractor(BaseExtractor):
    """Content-feature extractor that uses a Whisper model's audio embedding."""

    def __init__(self, config):
        super().__init__(config)
        self.extractor_type = "whisper"

    def load_model(self):
        """Load the Whisper model named by ``cfg.preprocess.whisper_model``.

        ``cfg.preprocess.whisper_model_path`` is forwarded as the checkpoint
        file when that key exists. The model is moved to the GPU when CUDA is
        available and stored in ``self.model`` in eval mode.
        """
        print("Loading Whisper Model...")

        preprocess_cfg = self.cfg.preprocess
        checkpoint_file = None
        if "whisper_model_path" in preprocess_cfg:
            checkpoint_file = preprocess_cfg.whisper_model_path

        model = whisper.load_model(
            preprocess_cfg.whisper_model, checkpoint_file=checkpoint_file
        )

        if not torch.cuda.is_available():
            print("Using CPU...\n")
        else:
            print("Using GPU...\n")
            model = model.cuda()

        self.model = model.eval()

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch_size, T)
            lens: list (not used by this extractor)
        Returns:
            The output of ``self.model.embed_audio`` for the batch.
        """
        # Bring every waveform to the length the whisper helper expects.
        fixed_length_wavs = whisper.pad_or_trim(wavs)
        # Log-mel spectrogram, moved to the model's device.
        mel = whisper.log_mel_spectrogram(fixed_length_wavs)
        mel = mel.to(self.model.device)
        with torch.no_grad():
            return self.model.embed_audio(mel)
| |
|
| |
|
class ContentvecExtractor(BaseExtractor):
    """Content-feature extractor backed by a fairseq ContentVec checkpoint."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.extractor_type = "contentvec"

    def load_model(self):
        """Load the checkpoint at ``cfg.preprocess.contentvec_file``.

        Must be called at most once (asserts that no model is loaded yet).
        The first model of the loaded ensemble is put in eval mode, moved to
        the GPU when CUDA is available, and stored in ``self.model``.
        """
        assert self.model is None

        checkpoint_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        ensemble, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path],
            suffix="",
        )
        contentvec = ensemble[0]
        contentvec.eval()

        if torch.cuda.is_available():
            contentvec = contentvec.cuda()

        self.model = contentvec

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
            lens: list (not used by this extractor)
        Returns:
            ``final_proj`` applied to the first element returned by
            ``extract_features`` (taken at output layer 12).
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)
        # NOTE(review): the mask flags every exactly-zero sample, not only the
        # batch padding; `lens` is ignored. Confirm this is intended.
        padding_mask = wavs.eq(torch.zeros_like(wavs)).to(device)
        with torch.no_grad():
            extracted = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            projected = self.model.final_proj(extracted[0])
        return projected
| |
|
| |
|
class WenetExtractor(BaseExtractor):
    """Content-feature extractor backed by a WeNet encoder."""

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config)
        self.extractor_type = "wenet"
        # Filled by load_model(); initialised here so that calling
        # extract_content_features() first fails with the intended
        # "load model first!" assertion rather than an AttributeError.
        self.extract_conf = None

    def load_model(self):
        """Build the WeNet model and load its checkpoint.

        Reads the YAML file at ``cfg.preprocess.wenet_config``, keeps a copy
        of its ``dataset_conf`` section for feature extraction, builds the
        model from the config, and loads the weights from
        ``cfg.preprocess.wenet_model_path``. The model is moved to the GPU
        when CUDA is available and put in eval mode.
        """
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # NOTE(security): yaml.load with FullLoader is kept for compatibility
        # with existing config files; only point wenet_config at trusted files.
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def _compute_acoustic_feats(self, wav, feats_type, feats_conf):
        """Compute kaldi fbank or mfcc features for one (1, T) waveform."""
        if feats_type == "fbank":
            return kaldi.fbank(
                wav,
                sample_frequency=16000,
                num_mel_bins=feats_conf["num_mel_bins"],
                frame_length=feats_conf["frame_length"],
                frame_shift=feats_conf["frame_shift"],
                dither=feats_conf["dither"],
            )
        return kaldi.mfcc(
            wav,
            sample_frequency=16000,
            num_mel_bins=feats_conf["num_mel_bins"],
            frame_length=feats_conf["frame_length"],
            frame_shift=feats_conf["frame_shift"],
            dither=feats_conf["dither"],
            num_ceps=feats_conf.get("num_ceps", 40),
            high_freq=feats_conf.get("high_freq", 0.0),
            low_freq=feats_conf.get("low_freq", 20.0),
        )

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor
            lens: list, ``lens[i]`` is the unpadded length of ``wavs[i]``
        Returns:
            The output of ``self.model.encoder_extractor`` for the batch.
        """
        assert self.extract_conf is not None, "load model first!"
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]

        # The config section is the same for every utterance, so resolve it
        # once. For mfcc, prefer the "mfcc_conf" key (parallel to
        # "fbank_conf") and fall back to the previously used "mfcc" key.
        if feats_type == "fbank":
            feats_conf = self.extract_conf.get("fbank_conf", {})
        else:
            feats_conf = self.extract_conf.get(
                "mfcc_conf", self.extract_conf.get("mfcc", {})
            )

        device = next(self.model.parameters()).device
        feats_list = []
        lengths_list = []

        # Inference only: do not build an autograd graph (same as the other
        # extractors in this file).
        with torch.no_grad():
            for idx, wav in enumerate(wavs):
                # Drop the batch padding of this utterance.
                wav = wav[: lens[idx]].to(device)

                # Append 160 zero samples, then scale by 2**15.
                # NOTE(review): presumably one 10 ms hop at 16 kHz and a
                # conversion to the 16-bit amplitude range -- confirm.
                pad_tensor = torch.zeros(160, device=wav.device)
                wav = torch.cat((wav, pad_tensor), dim=-1)
                wav *= 1 << 15

                feat = self._compute_acoustic_feats(
                    wav.unsqueeze(0), feats_type, feats_conf
                )
                feats_list.append(feat)
                lengths_list.append(feat.shape[0])

            feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
            feats_tensor = pad_sequence(feats_list, batch_first=True).to(device)

            features = self.model.encoder_extractor(
                feats_tensor,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False,
            )
        return features
| |
|
| |
|
class MertExtractor(BaseExtractor):
    """Content-feature extractor backed by a MERT model (Hugging Face)."""

    # Directory the previous revision loaded from unconditionally. It is kept
    # only as a fallback so that setups relying on it keep working.
    _LEGACY_LOCAL_MERT_PATH = "/mnt/workspace/fangzihao/acce/Amphion/pretrained/MERT"

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg)
        self.extractor_type = "mert"
        self.preprocessor = None

    def _resolve_model_source(self):
        """Return the path or model id to load MERT from.

        Order: ``cfg.preprocess.mert_model_path`` when configured, then the
        legacy local directory when it exists on this machine, then
        ``cfg.preprocess.mert_model``.
        """
        preprocess_cfg = self.cfg.preprocess
        if "mert_model_path" in preprocess_cfg and preprocess_cfg.mert_model_path:
            return preprocess_cfg.mert_model_path
        if os.path.isdir(self._LEGACY_LOCAL_MERT_PATH):
            return self._LEGACY_LOCAL_MERT_PATH
        return preprocess_cfg.mert_model

    def load_model(self):
        """Load the MERT model and its feature preprocessor.

        Must be called at most once (asserts that nothing is loaded yet).
        The model is moved to the GPU when CUDA is available; the
        preprocessor is loaded from the same source in either case.
        """
        assert self.model is None
        assert self.preprocessor is None

        model_source = self._resolve_model_source()
        print("Loading MERT Model: ...", model_source)

        model = AutoModel.from_pretrained(model_source, trust_remote_code=True)

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_source, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
            lens: list, ``lens[i]`` is the unpadded length of ``wavs[i]``
        Returns:
            list with one entry per utterance: the hidden states of layer
            ``cfg.preprocess.mert_feature_layer`` with the batch axis removed.
        """
        sample_rate = self.preprocessor.sampling_rate
        assert (
            sample_rate == self.cfg.preprocess.mert_sample_rate
        ), "mert sample rate mismatch, expected {}, got {}".format(
            self.cfg.preprocess.mert_sample_rate, sample_rate
        )
        device = next(self.model.parameters()).device
        feature_layer = self.cfg.preprocess.mert_feature_layer

        mert_features = []
        with torch.no_grad():
            for idx, wav in enumerate(wavs):
                # Process one utterance at a time and drop its batch padding.
                # The earlier code passed the whole batch `wavs` here on
                # every iteration.
                wav = wav[: lens[idx]]
                inputs = self.preprocessor(
                    wav, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # Remove the leading batch axis of size 1.
                feature = outputs.hidden_states[feature_layer].squeeze(0)
                mert_features.append(feature)

        return mert_features
| |
|
| |
|
def _extract_and_save_content_features(
    cfg, metadata, num_workers, feature_name, dataset_cls, extractor_cls
):
    """Extract one kind of content feature for all utterances and save it.

    Extraction is skipped when ``{processed_dir}/{dataset}/{feature_name}``
    already holds as many entries as ``metadata``.
    """
    dataset_name = metadata[0]["Dataset"]
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feature_name)
    os.makedirs(feat_dir, exist_ok=True)
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    # Read lazily (e.g. cfg.preprocess.whisper_sample_rate) so the key is
    # only required when extraction actually runs.
    sample_rate = getattr(cfg.preprocess, feature_name + "_sample_rate")
    waveforms = dataset_cls(cfg, dataset_name, sample_rate, metadata=metadata)
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )
    extractor = extractor_cls(cfg)
    extractor.load_model()
    for _metadata, wavs, lens in tqdm(data_loader):
        batch_content_features = extractor.extract_content_features(wavs, lens)
        for index, utt in enumerate(_metadata):
            extractor.save_feature(utt, batch_content_features[index])


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract and save every content feature enabled in ``cfg.preprocess``.

    For each of whisper / contentvec / wenet / mert whose
    ``extract_<name>_feature`` flag is set, the waveforms are loaded through
    a DataLoader, passed through the matching extractor batch by batch, and
    saved per utterance. The dataset name is taken from ``metadata[0]``.

    Args:
        cfg: config object with a ``preprocess`` section.
        metadata (list): utterance dicts; nothing is done when it is empty.
        num_workers (int): number of DataLoader worker processes.
    """
    if len(metadata) == 0:
        # Nothing to extract; also avoids indexing metadata[0] below.
        return

    if cfg.preprocess.extract_whisper_feature:
        _extract_and_save_content_features(
            cfg, metadata, num_workers, "whisper", FFmpegDataset, WhisperExtractor
        )

    if cfg.preprocess.extract_contentvec_feature:
        _extract_and_save_content_features(
            cfg,
            metadata,
            num_workers,
            "contentvec",
            LibrosaDataset,
            ContentvecExtractor,
        )

    if cfg.preprocess.extract_wenet_feature:
        _extract_and_save_content_features(
            cfg, metadata, num_workers, "wenet", TorchaudioDataset, WenetExtractor
        )

    if cfg.preprocess.extract_mert_feature:
        _extract_and_save_content_features(
            cfg, metadata, num_workers, "mert", TorchaudioDataset, MertExtractor
        )
| |
|