import os
import types

import audinterface
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import transformers
from torch import nn
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

# Long-form synthesised / human reference audio to analyse
FULL_WAV = [
    'english_hfullh.wav',
    'english_4x_hfullh.wav',
    'human_hfullh.wav',
    'foreign_hfullh.wav',
    'foreign_4x_hfullh.wav',
]
WIN = 40  # sliding-window duration in seconds
HOP = 10  # hop between consecutive windows in seconds
# 3 emotional dimensions (arousal/dominance/valence) followed by
# 8 emotion categories; order matches the output of process_function below
LABELS = ['arousal', 'dominance', 'valence',
          'Angry', 'Sad', 'Happy', 'Surprise',
          'Fear', 'Disgust', 'Contempt', 'Neutral']

# A Wav2Vec2Config instance doubles as a namespace for the target devices;
# both models live on the same GPU here.
config = transformers.Wav2Vec2Config()
config.dev = torch.device('cuda:0')
config.dev2 = torch.device('cuda:0')

def _softmax(x):
    '''x : (batch, num_class)'''
    x = x - x.max(1, keepdims=True)  # shift for numerical stability
    x = np.maximum(-100, x)          # clip very negative logits before exp
    x = np.exp(x)
    return x / x.sum(1, keepdims=True)

def _sigmoid(x):
    '''x : (batch, num_class)'''
    return 1 / (1 + np.exp(-x))
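
# For every long-form file: if a cached time-series of predictions exists,
# re-use it; otherwise run both emotion models over the file with a sliding
# window and pickle the resulting DataFrame.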
for long_audio in FULL_WAV:
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    if not os.path.exists(file_interface):
        print('_______________________________________\n'
              f'Processing {file_interface}\n'
              '_______________________________________')

        def _infer(self, x):
            '''x: (batch, audio-samples-16KHz)'''
            x = (x + self.config.mean) / self.config.std  # input normalisation
            x = self.ssl_model(x, attention_mask=None).last_hidden_state
            # Attentive statistics pooling: attention-weighted mean and std
            h = self.pool_model.sap_linear(x).tanh()
            w = torch.matmul(h, self.pool_model.attention)
            w = w.softmax(1)
            mu = (x * w).sum(1)
            x = torch.cat(
                [
                    mu,
                    ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
                ], 1)
            return self.ser_model(x)

        # Teacher 1: categorical SER model (8 emotion classes)
        teacher_cat = AutoModelForAudioClassification.from_pretrained(
            '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
            trust_remote_code=True
        ).to(config.dev2).eval()
        teacher_cat.forward = types.MethodType(_infer, teacher_cat)
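
        # Teacher 2: dimensional model predicting arousal/dominance/valence,
        # built on audEERING's wav2vec2-large-robust fine-tuned on MSP-Podcast.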
        def _prenorm(x, attention_mask=None):
            '''Zero-mean / unit-variance normalisation per utterance.'''
            if attention_mask is not None:
                N = attention_mask.sum(1, keepdim=True)
                x -= x.sum(1, keepdim=True) / N
                var = (x * x).sum(1, keepdim=True) / N
            else:
                x -= x.mean(1, keepdim=True)
                var = (x * x).mean(1, keepdim=True)
            return x / torch.sqrt(var + 1e-7)

        class RegressionHead(nn.Module):
            r"""Regression head mapping pooled hidden states to A/D/V scores."""

            def __init__(self, config):
                super().__init__()
                self.dense = nn.Linear(config.hidden_size, config.hidden_size)
                self.dropout = nn.Dropout(config.final_dropout)
                self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

            def forward(self, features, **kwargs):
                x = features
                x = self.dropout(x)
                x = self.dense(x)
                x = torch.tanh(x)
                x = self.dropout(x)
                x = self.out_proj(x)
                return x

        class Dawn(Wav2Vec2PreTrainedModel):
            r"""Dimensional speech-emotion model (arousal/dominance/valence)."""

            def __init__(self, config):
                super().__init__(config)
                self.config = config
                self.wav2vec2 = Wav2Vec2Model(config)
                self.classifier = RegressionHead(config)
                self.init_weights()

            def forward(
                self,
                input_values,
                attention_mask=None,
            ):
                x = _prenorm(input_values, attention_mask=attention_mask)
                outputs = self.wav2vec2(x, attention_mask=attention_mask)
                hidden_states = outputs[0]
                hidden_states = torch.mean(hidden_states, dim=1)  # mean-pool over time
                logits = self.classifier(hidden_states)
                return logits

        dawn = Dawn.from_pretrained(
            'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
        ).to(config.dev).eval()
        def process_function(x, sampling_rate, idx):
            '''Run both teachers on a single window of 16 kHz audio.

            Returns (batch, 11): 3 A/D/V scores followed by 8 categorical
            emotion probabilities, in the order given by LABELS.
            '''
            x = torch.from_numpy(x).to(config.dev)
            logits_cat = teacher_cat(x).cpu().detach().numpy()
            logits_adv = dawn(x).cpu().detach().numpy()
            cat = np.concatenate([logits_adv,
                                  _softmax(logits_cat)],
                                 1)
            print(cat)
            return cat
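
        # audinterface slides a WIN-second window with a HOP-second hop over
        # the file and calls process_function on each window; the result is a
        # DataFrame indexed by (file, start, end) with one column per label.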
        interface = audinterface.Feature(
            feature_names=LABELS,
            process_func=process_function,
            process_func_applies_sliding_window=False,
            win_dur=WIN,
            hop_dur=HOP,
            sampling_rate=16000,
            resample=True,
            verbose=True,
        )
        df_pred = interface.process_file(long_audio)
        df_pred.to_pickle(file_interface)
    else:
        print(file_interface, 'FOUND')

# Load all cached predictions and track the shortest series so the files can
# be compared over a common number of windows.
preds = {}
SHORTEST_PD = 100000
for long_audio in FULL_WAV:
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    y = pd.read_pickle(file_interface)
    preds[long_audio] = y
    SHORTEST_PD = min(SHORTEST_PD, len(y))
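
# Trim every series to the shortest one and flatten the (file, start, end)
# index to the window end time in seconds, so all series share one x-axis.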
for k, v in preds.items():
    p = v[:SHORTEST_PD]
    p.reset_index(inplace=True)
    p.drop(columns=['file', 'start'], inplace=True)
    p.set_index('end', inplace=True)
    p.index = p.index.map(lambda t: t.total_seconds())
    preds[k] = p

print(preds.keys())

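# One figure per language: 8 rows (3 A/D/V dimensions + 5 emotion categories)
# by 2 columns (normal speed | 4x speed). The grey fill is the human
# reference recording; the solid curve is the synthesised speech.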
for lang in ['english', 'foreign']:
    fig, ax = plt.subplots(nrows=8, ncols=2, figsize=(24, 20.7),
                           gridspec_kw={'hspace': 0, 'wspace': .04})

    time_stamp = preds['human_hfullh.wav'].index.to_numpy()

    # Top 3 rows: dimensional attributes
    for j, dim in enumerate(['arousal',
                             'dominance',
                             'valence']):
        # Left column: normal speed vs human reference
        ax[j, 0].plot(time_stamp, preds[f'{lang}_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 0].fill_between(time_stamp,
                              0 * preds[f'{lang}_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)
        if j == 0:
            if lang == 'english':
                desc = 'English'
            else:
                desc = 'Non-English'
            ax[j, 0].legend([f'StyleTTS2 using Mimic-3 {desc}',
                             'StyleTTS2 using EmoDB'],
                            prop={'size': 14},
                            )
        ax[j, 0].set_ylabel(dim.lower(), color=(.4, .4, .4), fontsize=17)
        ax[j, 0].set_ylim([1e-7, .9999])
        ax[j, 0].set_xticklabels(['' for _ in ax[j, 0].get_xticklabels()])
        ax[j, 0].set_xlim([time_stamp[0], time_stamp[-1]])

        # Right column: 4x speed vs human reference
        ax[j, 1].plot(time_stamp, preds[f'{lang}_4x_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 1].fill_between(time_stamp,
                              0 * preds[f'{lang}_4x_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)
        if j == 0:
            if lang == 'english':
                desc = 'English'
            else:
                desc = 'Non-English'
            ax[j, 1].legend([f'StyleTTS2 using Mimic-3 {desc} 4x speed',
                             'StyleTTS2 using EmoDB'],
                            prop={'size': 14},
                            )
        ax[j, 1].set_xlabel('720 Harvard Sentences')
        ax[j, 1].set_ylim([1e-7, .9999])
        ax[j, 1].set_xticklabels(['' for _ in ax[j, 1].get_xticklabels()])
        ax[j, 1].set_xlim([time_stamp[0], time_stamp[-1]])

        ax[j, 0].grid()
        ax[j, 1].grid()

    # Remaining 5 rows: emotion categories (Surprise, Contempt and Neutral
    # are not plotted)
    for j, dim in enumerate(['Angry',
                             'Sad',
                             'Happy',
                             'Fear',
                             'Disgust',
                             ]):
        j = j + 3  # rows 0-2 hold the A/D/V dimensions
        # Left column: normal speed vs human reference
        ax[j, 0].plot(time_stamp, preds[f'{lang}_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 0].fill_between(time_stamp,
                              0 * preds[f'{lang}_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)
        ax[j, 0].set_ylabel(dim.lower(), color=(.4, .4, .4), fontsize=17)
        ax[j, 0].set_ylim([1e-7, .9999])
        ax[j, 0].set_xlim([time_stamp[0], time_stamp[-1]])
        ax[j, 0].set_xticklabels(['' for _ in ax[j, 0].get_xticklabels()])
        ax[j, 0].set_xlabel('720 Harvard Sentences', fontsize=17, color=(.2, .2, .2))

        # Right column: 4x speed vs human reference
        ax[j, 1].plot(time_stamp, preds[f'{lang}_4x_hfullh.wav'][dim],
                      color=(0, 104/255, 139/255),
                      label='mean_1',
                      linewidth=2)
        ax[j, 1].fill_between(time_stamp,
                              0 * preds[f'{lang}_4x_hfullh.wav'][dim],
                              preds['human_hfullh.wav'][dim],
                              color=(.2, .2, .2),
                              alpha=0.244)
        ax[j, 1].set_xlabel('720 Harvard Sentences', fontsize=17, color=(.2, .2, .2))
        ax[j, 1].set_ylim([1e-7, .9999])
        ax[j, 1].set_xticklabels(['' for _ in ax[j, 1].get_xticklabels()])
        ax[j, 1].set_xlim([time_stamp[0], time_stamp[-1]])

        ax[j, 0].grid()
        ax[j, 1].grid()

    plt.savefig(f'fig_{lang}_{WIN=}_{HOP=}_HFdisc.png', bbox_inches='tight')
    plt.close()
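
# Hedged usage sketch: inspect one cached time-series offline (file name
# follows the naming scheme above); uncomment to run.
#
#     df = pd.read_pickle('timeseries_english_hfullh.wav.pkl')
#     print(df[['arousal', 'valence']].describe())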