# Hugging Face Space: age and gender recognition from speech.
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn as nn | |
| from transformers import Wav2Vec2Processor | |
| from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model | |
| from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel | |
| import audiofile | |
class ModelHead(nn.Module):
    r"""Two-layer MLP head mapping pooled wav2vec2 features to label logits."""

    def __init__(self, config, num_labels):
        super().__init__()
        # hidden -> hidden (tanh) -> num_labels, with dropout before each linear.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        hidden = self.dropout(features)
        hidden = torch.tanh(self.dense(hidden))
        hidden = self.dropout(hidden)
        return self.out_proj(hidden)
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age and gender prediction model on top of a wav2vec2 encoder.

    Produces the time-pooled hidden states together with an age regression
    output and softmax probabilities over three gender classes
    (female / male / child, per the labels used downstream).

    Note: the original docstring called this a "speech emotion classifier",
    which contradicted both the class name and the two heads below.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        # Single regression output; the caller scales it by 100 to get years.
        self.age = ModelHead(config, 1)
        # Three-way classification head (female / male / child).
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        outputs = self.wav2vec2(input_values)
        # Mean-pool the sequence dimension to one vector per utterance.
        hidden_states = torch.mean(outputs[0], dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender
# Load processor and model from the Hugging Face hub.
device = 0 if torch.cuda.is_available() else "cpu"
model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)
# Move the model to the same device the inputs are sent to in process_func();
# without this, inference crashes with a device mismatch on CUDA machines.
model.to(device)
def process_func(x: np.ndarray, sampling_rate: int) -> list:
    r"""Predict age and gender from a raw audio signal.

    Args:
        x: waveform, shape ``(samples,)`` or ``(channels, samples)``
            as returned by ``audiofile.read``.
        sampling_rate: sampling rate of ``x`` in Hz.

    Returns:
        List of ``{"score": ..., "label": ...}`` dicts: age (in years)
        followed by the female / male / child probabilities.
    """
    # Downmix multi-channel audio to mono. The previous reshape(1, -1)
    # alone would have concatenated the channels end to end.
    if x.ndim > 1:
        x = x.mean(axis=0)
    # Run through the processor to normalize the signal; it always returns
    # a batch, so take the first entry, batch it and move it to the device.
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)
    # Run through the model without gradient tracking.
    with torch.no_grad():
        y = model(y)
        # y[1] is the age output, y[2] the gender probabilities.
        y = torch.hstack([y[1], y[2]])
    y = y.detach().cpu().numpy()
    # The age head emits age/100, so scale back to years; the remaining
    # entries are softmax probabilities.
    return [
        {"score": 100 * y[0][0], "label": "age"},
        {"score": y[0][1], "label": "female"},
        {"score": y[0][2], "label": "male"},
        {"score": y[0][3], "label": "child"},
    ]
def recognize(file):
    r"""Gradio handler: read the submitted audio file and classify it."""
    if file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    signal, sampling_rate = audiofile.read(file)
    return process_func(signal, sampling_rate)
# --- Gradio UI --------------------------------------------------------------
demo = gr.Blocks()
# gr.outputs was removed in Gradio 4 (the rest of the file already uses the
# modern API, e.g. gr.Audio(sources=...)); use gr.Label directly.
outputs = gr.Label()
title = "audEERING age and gender recognition"
# Must be an f-string, otherwise the literal "{model_name}" shows in the UI.
description = (
    "Recognize age and gender of a microphone recording or audio file. "
    f"Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
)
allow_flagging = "never"
microphone = gr.Interface(
    fn=recognize,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=outputs,
    title=title,
    description=description,
    allow_flagging=allow_flagging,
)
file = gr.Interface(
    fn=recognize,
    inputs=gr.Audio(sources="upload", type="filepath", label="Audio file"),
    outputs=outputs,
    title=title,
    description=description,
    allow_flagging=allow_flagging,
)
with demo:
    gr.TabbedInterface([microphone, file], ["Microphone", "Audio file"])
demo.queue().launch()