---
license: cc-by-nc-sa-4.0
pipeline_tag: audio-classification
tags:
- wav2small
- valence
- arousal
- dominance
- speech
- speech-emotion-recognition
---

# Wav2Small2.0 - Arousal / Dominance / Valence

Please note that this model is for research purposes only. A commercial [license](https://www.audeering.com/products/devaice/) can be acquired from audEERING. The model expects a raw audio signal sampled at 16 kHz as input and outputs arousal, dominance, and valence in the range [0, 1]. The model is created following the [Wav2Small paper](https://arxiv.org/abs/2408.13920) and has a total of 17K parameters.

# How To

```python
import torch
import numpy as np
import librosa
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
from torch import nn

signal = torch.from_numpy(
    librosa.load('test.wav', sr=16000)[0])[None, :]
device = 'cpu'


def _prenorm(x, attention_mask=None):
    '''Zero-mean / unit-variance normalisation of the raw wav2vec2 input.'''
    if attention_mask is not None:
        # The attention mask is unprocessed, i.e. at the resolution of the raw input.
        N = attention_mask.sum(1, keepdim=True)
        x -= x.sum(1, keepdim=True) / N
        var = (x * x).sum(1, keepdim=True) / N
    else:
        # mean() maps to the ONNX ReduceMean operator and saves some ops
        # compared to casting the integer N to float and dividing.
        x -= x.mean(1, keepdim=True)
        var = (x * x).mean(1, keepdim=True)
    return x / torch.sqrt(var + 1e-7)


class Spectrogram(nn.Module):

    def __init__(self,
                 n_fft=64,       # num cols of DFT matrix
                 n_time=64,      # num rows of DFT matrix
                 hop_length=32,
                 freeze_parameters=True):
        super().__init__()
        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
        fft_window = librosa.util.pad_center(fft_window, size=n_time)
        out_channels = n_fft // 2 + 1
        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
        omega = np.exp(-2 * np.pi * 1j / n_time)
        dft_matrix = np.power(omega, x * y)  # (n_fft, n_time)
        dft_matrix = dft_matrix * fft_window[None, :]
        dft_matrix = dft_matrix[0:out_channels, :]
        dft_matrix = dft_matrix[:, None, :]
        # ---- asymmetric (non-square) DFT
        self.conv_real = nn.Conv1d(1, out_channels, n_fft,
                                   stride=hop_length, padding=0, bias=False)
        self.conv_imag = nn.Conv1d(1, out_channels, n_fft,
                                   stride=hop_length, padding=0, bias=False)
        self.conv_real.weight.data = torch.tensor(
            np.real(dft_matrix),
            dtype=self.conv_real.weight.dtype,
            device=self.conv_real.weight.device)
        self.conv_imag.weight.data = torch.tensor(
            np.imag(dft_matrix),
            dtype=self.conv_imag.weight.dtype,
            device=self.conv_imag.weight.device)
        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        x = input[:, None, :]
        real = self.conv_real(x)
        imag = self.conv_imag(x)
        return real ** 2 + imag ** 2  # (bs, freq, time-frames)


class LogmelFilterBank(nn.Module):

    def __init__(self,
                 sr=16000,
                 n_fft=64,
                 n_mels=26,  # halved to 13 by the max pool in Vgg7
                 fmin=0.0,
                 freeze_parameters=True):
        super().__init__()
        fmax = sr // 2
        W2 = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
                                 fmin=fmin, fmax=fmax).T
        self.register_buffer('melW', torch.Tensor(W2))
        self.register_buffer('amin', torch.Tensor([1e-10]))

    def forward(self, x):
        # changes the mel dimension, not the number of frames
        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)
        x = torch.where(x > self.amin, x, self.amin)  # not in place
        x = 10 * torch.log10(x)
        return x


class Conv(nn.Module):

    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, k,
                              stride=stride, padding=padding, bias=False)
        self.norm = nn.BatchNorm2d(c_out)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return torch.relu_(x)


class Vgg7(nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = Conv(1, 13)
        self.l2 = Conv(13, 13)
        self.l3 = Conv(13, 13)
        self.maxpool_A = nn.MaxPool2d(3, stride=2, padding=1)
        self.l4 = Conv(13, 13)
        self.l5 = Conv(13, 13)
        self.l6 = Conv(13, 13)
        self.l7 = Conv(13, 13)
        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)
        self.spectrogram_extractor = Spectrogram()
        self.logmel_extractor = LogmelFilterBank()

    def forward(self, x, attention_mask=None):
        x = _prenorm(x, attention_mask=attention_mask)
        x = self.spectrogram_extractor(x)
        x = self.logmel_extractor(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.maxpool_A(x)  # reshape here? so these convs would have a large kernel
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.lin(x) * self.sof(x).softmax(2)  # (bs, ch, time-frames, mel)
        x = x.sum(2)
        x = torch.cat([x, torch.bmm(x, x.transpose(1, 2))], 2)  # cosine over mel dims
        return x.reshape(-1, 338)


class Wav2SmallConfig(PretrainedConfig):

    model_type = "wav2vec2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.half_mel = 13
        self.n_fft = 64
        self.n_time = 64
        self.hidden = 2 * self.half_mel * self.half_mel
        self.hop = self.n_time // 2


class Wav2Small(Wav2Vec2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.vgg7 = Vgg7()
        self.adv = nn.Linear(config.hidden, 3)  # 0=arousal, 1=dominance, 2=valence

    def forward(self, x, attention_mask=None):
        x = self.vgg7(x, attention_mask=attention_mask)
        return self.adv(x)


model = Wav2Small.from_pretrained(
    'audeering/wav2small').to(device).eval()

with torch.no_grad():
    logits = model(signal.to(device))
print(f'\nArousal={logits[:, 0]}\n',
      f'Dominance={logits[:, 1]}\n',
      f'Valence={logits[:, 2]}\n')
```
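
The `forward` pass also accepts an `attention_mask`, which `_prenorm` uses to ignore padded samples when normalising. Below is a minimal sketch of batched inference on files of unequal duration, assuming a 0/1 mask at raw-sample resolution (1 = valid audio, 0 = padding); the file names `test_a.wav` and `test_b.wav` are placeholders.

```python
# Minimal batched-inference sketch (assumption: the mask is 0/1 at
# raw-sample resolution, 1 = valid audio, 0 = padding, as _prenorm expects).
wavs = [librosa.load(f, sr=16000)[0] for f in ('test_a.wav', 'test_b.wav')]
max_len = max(len(w) for w in wavs)
batch = torch.zeros(len(wavs), max_len)
mask = torch.zeros(len(wavs), max_len)
for i, w in enumerate(wavs):
    batch[i, :len(w)] = torch.from_numpy(w)
    mask[i, :len(w)] = 1

with torch.no_grad():
    adv = model(batch.to(device), attention_mask=mask.to(device))
print(adv.shape)  # (2, 3): arousal, dominance, valence per file
```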