import json
import os
import types
from pathlib import Path
from random import shuffle

import audb
import audiofile
import audmodel
import audresample
import jiwer
import matplotlib.pyplot as plt
import msinference
import numpy as np
import pandas as pd
import soundfile
import torch
import transformers
from torch import nn
from transformers import (AutoModelForAudioClassification,
                          AutoModelForSpeechSeq2Seq,
                          AutoProcessor,
                          pipeline)
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                            Wav2Vec2PreTrainedModel)


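# config is used purely as a namespace for the target devices: dev hosts Dawn
# and the Whisper pipeline, dev2 hosts the categorical SER teacher.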
config = transformers.Wav2Vec2Config()
config.dev = torch.device('cuda:0')
config.dev2 = torch.device('cuda:0')


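# The first 3 labels are the dimensional outputs of Dawn (arousal, dominance,
# valence); the remaining 8 are the categorical classes of the WavLM teacher.
# process_function() concatenates the two predictions in exactly this order.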
LABELS = ['arousal', 'dominance', 'valence',
          'Angry',
          'Sad',
          'Happy',
          'Surprise',
          'Fear',
          'Disgust',
          'Contempt',
          'Neutral',
          ]


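# Attentive-statistics pooling forward pass; it is monkey-patched onto the
# categorical SER teacher below so the teacher accepts raw 16 kHz waveforms.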
def _infer(self, x):
    '''x: (batch, audio samples at 16 kHz)'''
    x = (x + self.config.mean) / self.config.std  # normalise with the model's stats
    x = self.ssl_model(x, attention_mask=None).last_hidden_state
    # attention weights over time, then attention-weighted mean and std
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention)
    w = w.softmax(1)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        ], 1)
    return self.ser_model(x)


# Categorical SER teacher; its forward is replaced by _infer so it maps raw
# 16 kHz audio straight to the 8 categorical emotion logits.
teacher_cat = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
    trust_remote_code=True
).to(config.dev2).eval()
teacher_cat.forward = types.MethodType(_infer, teacher_cat)


def _prenorm(x, attention_mask=None):
    '''Per-utterance zero-mean, unit-variance normalisation.'''
    if attention_mask is not None:
        N = attention_mask.sum(1, keepdim=True)
        x -= x.sum(1, keepdim=True) / N
        var = (x * x).sum(1, keepdim=True) / N
    else:
        x -= x.mean(1, keepdim=True)
        var = (x * x).mean(1, keepdim=True)
    return x / torch.sqrt(var + 1e-7)


class RegressionHead(nn.Module):
    r"""Projection head: dropout -> dense -> tanh -> dropout -> linear."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


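# Dawn: the audeering wav2vec2 model fine-tuned for dimensional emotion; its
# three outputs are read as arousal, dominance and valence (the first three
# LABELS entries).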
class Dawn(Wav2Vec2PreTrainedModel):
    r"""Dimensional speech emotion model (arousal, dominance, valence)."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(self,
                input_values,
                attention_mask=None):
        x = _prenorm(input_values, attention_mask=attention_mask)
        outputs = self.wav2vec2(x, attention_mask=attention_mask)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)  # mean-pool over time
        logits = self.classifier(hidden_states)
        return logits


dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
).to(config.dev).eval()


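# Whisper large-v3 ASR pipeline; it transcribes both the style prompt and the
# StyleTTS2 output so character error rates can be compared per sentence.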
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
).to(config.dev)
processor = AutoProcessor.from_pretrained(model_id)
_pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=config.dev,
)


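# Run both emotion teachers on a 16 kHz waveform and return one 11-dimensional
# vector (3 A/D/V values from Dawn followed by 8 softmaxed class probabilities),
# matching the order of LABELS.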
def process_function(x, sampling_rate, idx):
    logits_cat = teacher_cat(torch.from_numpy(x).to(config.dev2)).softmax(1)
    logits_adv = dawn(torch.from_numpy(x).to(config.dev))

    out = torch.cat([logits_adv,
                     logits_cat],
                    1).cpu().detach().numpy()
    return out[0, :]


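# Natural emotional speech (the "human" prompt condition) is pulled from audb;
# currently only EmoDB is listed.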
def load_speech(split=None):
    DB = [
        # [database, version, table, has_timedeltas]
        ['emodb', '1.2.0', 'emotion.categories.train.gold_standard', False],
    ]

    output_list = []
    for database_name, ver, table, has_timedeltas in DB:
        a = audb.load(database_name,
                      sampling_rate=16000,
                      format='wav',
                      mixdown=True,
                      version=ver,
                      cache_root='/cache/audb/')
        a = a[table].get()
        if has_timedeltas:
            # Segmented tables (file, start, end) are not handled here.
            print(f'{has_timedeltas=}')
        else:
            output_list += [f for f in a.index]
    return output_list


natural_wav_paths = load_speech()


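# harvard.json holds the Harvard sentences as lists of 10; every sentence is
# synthesised once per prompt condition in the loop below.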
with open('harvard.json', 'r') as f:
    harvard_individual_sentences = json.load(f)['sentences']


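# Style prompts: English synthetic speech ('enslow', 'style_vector_v2') and
# foreign-language Mimic-3 speech (plain and _4x variants). English entries
# ('en_U') are dropped from the foreign sets, and clips of 2 s or less are
# discarded from every list.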
synthetic_wav_paths = ['./enslow/' + i for i in
                       os.listdir('./enslow/')]
synthetic_wav_paths_4x = ['./style_vector_v2/' + i for i in
                          os.listdir('./style_vector_v2/')]
synthetic_wav_paths_foreign = ['./mimic3_foreign/' + i
                               for i in os.listdir('./mimic3_foreign/')
                               if 'en_U' not in i]
synthetic_wav_paths_foreign_4x = ['./mimic3_foreign_4x/' + i
                                  for i in os.listdir('./mimic3_foreign_4x/')
                                  if 'en_U' not in i]

# Keep only prompts longer than 2 seconds.
synthetic_wav_paths_foreign = [i for i in synthetic_wav_paths_foreign
                               if audiofile.duration(i) > 2]
synthetic_wav_paths_foreign_4x = [i for i in synthetic_wav_paths_foreign_4x
                                  if audiofile.duration(i) > 2]
synthetic_wav_paths = [i for i in synthetic_wav_paths
                       if audiofile.duration(i) > 2]
synthetic_wav_paths_4x = [i for i in synthetic_wav_paths_4x
                          if audiofile.duration(i) > 2]

shuffle(synthetic_wav_paths_foreign_4x)
shuffle(synthetic_wav_paths_foreign)
shuffle(synthetic_wav_paths)
shuffle(synthetic_wav_paths_4x)
print(len(synthetic_wav_paths_foreign_4x), len(synthetic_wav_paths_foreign),
      len(synthetic_wav_paths), len(synthetic_wav_paths_4x))


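# For every prompt condition, synthesise each Harvard sentence with StyleTTS2,
# measure the emotion of the prompt and of the output, compute Whisper CERs,
# and cache the resulting table as <condition>_analytic.pkl.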
for audio_prompt in ['english',
                     'english_4x',
                     'human',
                     'foreign',
                     'foreign_4x']:

    data = np.zeros((770, len(LABELS) * 2 + 2))

    OUT_FILE = f'{audio_prompt}_analytic.pkl'
    if not os.path.isfile(OUT_FILE):
        ix = 0
        for list_of_10 in harvard_individual_sentences[:10004]:
            for text in list_of_10['sentences']:
                # Pick the next style prompt of this condition (cycled over ix).
                if audio_prompt == 'english':
                    _p = synthetic_wav_paths[ix % len(synthetic_wav_paths)]
                    style_vec = msinference.compute_style(_p)
                elif audio_prompt == 'english_4x':
                    _p = synthetic_wav_paths_4x[ix % len(synthetic_wav_paths_4x)]
                    style_vec = msinference.compute_style(_p)
                elif audio_prompt == 'human':
                    _p = natural_wav_paths[ix % len(natural_wav_paths)]
                    style_vec = msinference.compute_style(_p)
                elif audio_prompt == 'foreign':
                    _p = synthetic_wav_paths_foreign[ix % len(synthetic_wav_paths_foreign)]
                    style_vec = msinference.compute_style(_p)
                elif audio_prompt == 'foreign_4x':
                    _p = synthetic_wav_paths_foreign_4x[ix % len(synthetic_wav_paths_foreign_4x)]
                    style_vec = msinference.compute_style(_p)
                else:
                    raise ValueError(f'Unknown style-prompt condition: {audio_prompt}')

                x = msinference.inference(text,
                                          style_vec,
                                          alpha=0.3,
                                          beta=0.7,
                                          diffusion_steps=7,
                                          embedding_scale=1)
                x = audresample.resample(x, 24000, 16000)

                # Emotion of the style prompt (resampled to 16 kHz) and of the output.
                _st, fsr = audiofile.read(_p)
                _st = audresample.resample(_st, fsr, 16000)
                print(_st.shape, x.shape)

                emotion_of_prompt = process_function(_st, 16000, None)
                emotion_of_out = process_function(x, 16000, None)
                data[ix, :11] = emotion_of_prompt
                data[ix, 11:22] = emotion_of_out

                # Whisper transcriptions for CER; the prompt is scored against a
                # fixed reference sentence, the output against its Harvard sentence.
                transcription_prompt = _pipe(_st[0])
                transcription_styletts2 = _pipe(x[0])
                print(transcription_prompt, transcription_styletts2)

                data[ix, 22] = jiwer.cer('Sweet dreams are made of this. I travel the world and the seven seas.',
                                         transcription_prompt['text'])
                data[ix, 23] = jiwer.cer(text,
                                         transcription_styletts2['text'])
                print(data[ix, :])

                ix += 1

        df = pd.DataFrame(
            data,
            columns=['prompt-' + i for i in LABELS]
                    + ['styletts2-' + i for i in LABELS]
                    + ['cer-prompt', 'cer-styletts2'])
        df.to_pickle(OUT_FILE)
    else:
        df = pd.read_pickle(OUT_FILE)
        print(f'\nALREADY EXISTS\n{df}')