import argparse
import os

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from tqdm import trange
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)
class Embeder:
    """Builds and saves audio (x-vector) and text (BERT) embeddings for the
    songs listed in the configured CSV of filenames."""

    def __init__(self, config):
        """Load the YAML config and the pretrained audio/text models.

        Args:
            config: filesystem path to an OmegaConf YAML file expected to
                provide ``path_csv``, ``path_data``, ``sample_rate``,
                ``path_audio_embedding`` and ``path_text_embedding``.
        """
        self.config = OmegaConf.load(config)
        # BUG FIX: the CSV path lives on the loaded config object, not on the
        # raw path string passed in (the original read `config.path_csv`,
        # which is an AttributeError on a str).
        self.df = pd.read_csv(self.config.path_csv)
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking"
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()

    def run(self):
        """Create and persist the audio embeddings, then the text embeddings."""
        self._create_audio_embed()
        self._create_text_embed()

    def _create_audio_embed(self):
        """Embed every WAV referenced by the dataframe, save the embeddings,
        and drop rows whose audio failed to embed from the CSV on disk."""
        audio_embed = None
        failed_rows = []
        for i in trange(len(self.df)):
            # WAV files are named "new_<original stem>.wav" under path_data.
            song = AudioSegment.from_wav(
                os.path.join(
                    self.config.path_data,
                    "new_" + self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
                )
            )
            samples = np.array(song.get_array_of_samples(), dtype="float")
            inputs = self.audio_feature_extractor(
                [samples],
                sampling_rate=self.config.sample_rate,
                return_tensors="pt",
                padding=True,
            )
            # Narrow try body to the model call only; the original also
            # wrapped the concatenation, hiding unrelated failures.
            try:
                with torch.no_grad():
                    embeddings = self.audio_model(**inputs).embeddings
            except Exception:
                # Best-effort: remember the failing row and keep going.
                failed_rows.append(i)
            else:
                audio_embed = (
                    embeddings
                    if audio_embed is None
                    else torch.concatenate([audio_embed, embeddings])
                )
        audio_embed = torch.nn.functional.normalize(audio_embed, dim=-1).cpu()
        self.clean_and_save_data(audio_embed, failed_rows)
        self.df = self.df.drop(index=failed_rows)
        self.df.to_csv(self.config.path_csv, index=False)

    def clean_and_save_data(self, audio_embed, idx):
        """Save the audio embedding tensor to ``config.path_audio_embedding``.

        BUG FIX: rows listed in *idx* never contributed a row to
        *audio_embed* (an embedding is appended only on success), so the old
        re-filtering compared misaligned indices, and its ``range(1, ...)``
        additionally dropped row 0 unconditionally. The tensor is already
        clean; save it as-is. *idx* is kept for interface compatibility.
        """
        torch.save(audio_embed, self.config.path_audio_embedding)

    def _create_text_embed(self):
        """Embed each remaining filename with Japanese BERT and save the
        stacked, normalized embeddings to ``config.path_text_embedding``."""
        text_embed = None
        for i in range(len(self.df)):
            sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
            tokenized_text = self.text_tokenizer.tokenize(sentence)
            indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
            tokens_tensor = torch.tensor([indexed_tokens])
            with torch.no_grad():
                outputs = self.text_model(tokens_tensor)
            # outputs[1] holds the hidden states (output_hidden_states=True);
            # take the second-to-last layer and mean-pool over tokens.
            embedding = torch.mean(outputs[1][-2][0], dim=0).reshape(1, -1)
            text_embed = (
                embedding
                if text_embed is None
                else torch.concatenate([text_embed, embedding])
            )
        text_embed = torch.nn.functional.normalize(text_embed, dim=-1).cpu()
        torch.save(text_embed, self.config.path_text_embedding)
def argparser():
    """Parse command-line options.

    Returns:
        argparse.Namespace with a single ``config`` attribute holding the
        path of the YAML configuration file (default: ``config.yaml``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        type=str,
        default="config.yaml",
        help="File path for config file.",
    )
    return cli.parse_args()
| if __name__ == "__main__": | |
| args = argparser() | |
| embeder = Embeder(args.config) | |
| embeder.run() | |