# -*- coding: utf-8 -*-
import typing
import types  # used to rebind forward() of the Wav2Vec2 sub-modules (feature fusion)
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
import audiofile
import unicodedata
import textwrap
from tts import StyleTTS2
import audresample

device = 0 if torch.cuda.is_available() else "cpu"
duration = 2  # limit processing to the first seconds of audio
age_gender_model_name = "audeering/wav2vec2-large-robust-6-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
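
# Note: the script assumes both checkpoints keep the wav2vec2 CNN feature
# extractor frozen (identical weights), so its output ("CNN7" features) is
# computed once by the expression model and reused by the age-gender model
# via the forward() rebinding further below.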


class AgeGenderHead(nn.Module):
    r"""Age-gender model head."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        self.init_weights()

    def forward(
            self,
            frozen_cnn7,
    ):
        hidden_states = self.wav2vec2(frozen_cnn7=frozen_cnn7)  # runs only the Transformer layers
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender


# The forward() of each model's inner Wav2Vec2Model is replaced below (via
# types.MethodType), so that the frozen CNN7 features computed once by
# ExpressionModel can be fed directly to AgeGenderModel.
def _forward(
        self,
        frozen_cnn7=None,  # CNN7 features of wav2vec2, computed once by the CNN feature extractor
        attention_mask=None):
    if attention_mask is not None:
        # compute reduced attention_mask corresponding to feature vectors
        attention_mask = self._get_feature_vector_attention_mask(
            frozen_cnn7.shape[1], attention_mask, add_adapter=False
        )
    hidden_states, _ = self.wav2vec2.feature_projection(frozen_cnn7)
    hidden_states = self.wav2vec2.encoder(
        hidden_states,
        attention_mask=attention_mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )[0]
    return hidden_states


def _forward_and_cnn7(
        self,
        input_values,
        attention_mask=None):
    frozen_cnn7 = self.wav2vec2.feature_extractor(input_values)
    frozen_cnn7 = frozen_cnn7.transpose(1, 2)
    if attention_mask is not None:
        # compute reduced attention_mask corresponding to feature vectors
        attention_mask = self.wav2vec2._get_feature_vector_attention_mask(
            frozen_cnn7.shape[1], attention_mask, add_adapter=False
        )
    hidden_states, _ = self.wav2vec2.feature_projection(frozen_cnn7)  # feature_projection is fine-tuned (not frozen)
    hidden_states = self.wav2vec2.encoder(
        hidden_states,
        attention_mask=attention_mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )[0]
    # feature_projection is trainable per model, so the shared features must be
    # taken from before the projection layer (frozen_cnn7).
    return hidden_states, frozen_cnn7


class ExpressionHead(nn.Module):
    r"""Expression model head."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Speech expression model."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        hidden_states, frozen_cnn7 = self.wav2vec2(input_values)
        hidden_states = torch.mean(hidden_states, dim=1)
        logits = self.classifier(hidden_states)
        return hidden_states, logits, frozen_cnn7


# Load models from the hub
age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name)
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = ExpressionModel.from_pretrained(expression_model_name)
age_gender_model.to(device).eval()
expression_model.to(device).eval()

# Rebind forward(): the expression model computes the CNN7 features once,
# the age-gender model reuses them.
age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
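
# Assigning a bound function to the `forward` attribute of these module
# *instances* makes nn.Module.__call__ dispatch to it, so `self.wav2vec2(...)`
# inside each model now runs _forward / _forward_and_cnn7 with `self` bound to
# the outer model. The Wav2Vec2Model class itself is left untouched.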


def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
    # batch audio
    y = expression_processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through the expression model; reuse its CNN7 features for age/gender
    with torch.no_grad():
        _, logits_expression, frozen_cnn7 = expression_model(y)
        _, logits_age, logits_gender = age_gender_model(frozen_cnn7=frozen_cnn7)

    # Plot arousal / dominance / valence values
    plot_expression(logits_expression[0, 0].item(),  # .item() implies detach().cpu()
                    logits_expression[0, 1].item(),
                    logits_expression[0, 2].item())
    expression_file = "expression.png"
    plt.savefig(expression_file)

    return (
        f"{round(100 * logits_age[0, 0].item())} years",  # age
        {
            "female": logits_gender[0, 0].item(),
            "male": logits_gender[0, 1].item(),
            "child": logits_gender[0, 2].item(),
        },
        expression_file,
    )
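
# Usage sketch (hypothetical numbers): given a 16 kHz mono numpy signal,
#   process_func(np.zeros(32000, dtype=np.float32), 16000)
# returns something like
#   ("31 years", {"female": 0.2, "male": 0.7, "child": 0.1}, "expression.png")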


def recognize(input_file):
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    # Resample to the sampling rate supported by the models
    target_rate = 16000
    signal = audresample.resample(signal, sampling_rate, target_rate)
    return process_func(signal, target_rate)


def explode(data):
    """
    Expand a 3D array by adding new voxels between existing ones.
    This is used to create the gaps in the 3D plot.
    """
    shape = data.shape
    new_shape = (2 * shape[0] - 1, 2 * shape[1] - 1, 2 * shape[2] - 1)
    new_data = np.zeros(new_shape, dtype=data.dtype)
    new_data[::2, ::2, ::2] = data
    return new_data
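
# Shape check: each axis of length n becomes 2 * n - 1, e.g. a (5, 5, 5) grid
# turns into a (9, 9, 9) grid with the original voxels at even indices and
# zero-valued gaps in between.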


def plot_expression(arousal, dominance, valence):
    '''Highlight the voxel at the (arousal, dominance, valence) position
    inside an N_PIX x N_PIX x N_PIX grid.'''
    N_PIX = 5
    _h = np.random.rand(N_PIX, N_PIX, N_PIX) * 1e-3
    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
    arousal, dominance, valence = (adv * N_PIX).astype(np.int64)  # find voxel
    _h[arousal, dominance, valence] = .22
    filled = np.ones((N_PIX, N_PIX, N_PIX), dtype=bool)
    # upscale the above voxel image, leaving gaps
    filled_2 = explode(filled)
    # Shrink the gaps
    x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
    x[1::2, :, :] += 1
    y[:, 1::2, :] += 1
    z[:, :, 1::2] += 1
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    f_2 = np.ones([2 * N_PIX - 1,
                   2 * N_PIX - 1,
                   2 * N_PIX - 1, 4], dtype=np.float64)
    f_2[:, :, :, 3] = explode(_h)
    cm = plt.get_cmap('cool')
    f_2[:, :, :, :3] = cm(f_2[:, :, :, 3])[..., :3]
    f_2[:, :, :, 3] = f_2[:, :, :, 3].clip(.01, .74)
    ecolors_2 = f_2
    ax.voxels(x, y, z, filled_2, facecolors=f_2, edgecolors=.006 * ecolors_2)
    ax.set_aspect('equal')
    ax.set_zticks([0, N_PIX])
    ax.set_xticks([0, N_PIX])
    ax.set_yticks([0, N_PIX])
    ax.set_zticklabels([f'{n/N_PIX:.2f}' for n in ax.get_zticks()])
    ax.set_zlabel('valence', fontsize=10, labelpad=0)
    ax.set_xticklabels([f'{n/N_PIX:.2f}' for n in ax.get_xticks()])
    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
    # dominance axis is flipped (see .994 - dominance above), hence the 1 - n labels
    ax.set_yticklabels([f'{1-n/N_PIX:.2f}' for n in ax.get_yticks()], rotation=90)
    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
    ax.grid(False)
    ax.plot([N_PIX, N_PIX], [0, N_PIX + .2], [N_PIX, N_PIX], 'g', linewidth=1)
    ax.plot([0, N_PIX], [N_PIX, N_PIX + .24], [N_PIX, N_PIX], 'k', linewidth=1)
    # Missing edges on the top face
    ax.plot([0, 0], [0, N_PIX], [N_PIX, N_PIX], 'darkred', linewidth=1)
    ax.plot([0, N_PIX], [0, 0], [N_PIX, N_PIX], 'darkblue', linewidth=1)
    # Set pane colors after plotting the lines
    ax.xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))
    # Restore the limits to prevent the plot from expanding
    ax.set_xlim(0, N_PIX)
    ax.set_ylim(0, N_PIX)
    ax.set_zlim(0, N_PIX)
    # plt.show()
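
# plot_expression() draws into a freshly created figure and returns nothing;
# process_func() relies on plt.savefig() picking up that current figure.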


# TTS
VOICES = [f'wav/{vox}' for vox in os.listdir('wav')]
_tts = StyleTTS2().to('cpu')


def only_greek_or_only_latin(text, lang='grc'):
    '''
    Returns:
        str: The converted string in the specified target script.
            Characters not found in any mapping are preserved as is.
            Latin accented characters in the input (e.g., 'É', 'ü') will
            be preserved in their lowercase form (e.g., 'é', 'ü') if
            converting to Latin.
    '''
    # --- Mapping Dictionaries ---
    # Keys are lowercase because the input text is lowercased below.
    # If the output needs to maintain the original casing, additional logic is required.
    latin_to_greek_map = {
        'a': 'α', 'b': 'β', 'g': 'γ', 'd': 'δ', 'e': 'ε',
        'ch': 'τσο',  # example of a multi-character Latin sequence
        'z': 'ζ', 'h': 'χ', 'i': 'ι', 'k': 'κ', 'l': 'λ',
        'm': 'μ', 'n': 'ν', 'x': 'ξ', 'o': 'ο', 'p': 'π',
        'v': 'β', 'sc': 'σκ', 'r': 'ρ', 's': 'σ', 't': 'τ',
        'u': 'ου', 'f': 'φ', 'c': 'σ', 'w': 'β', 'y': 'γ',
    }
    greek_to_latin_map = {
        'ου': 'ou',  # prioritise common diphthongs/digraphs
        'α': 'a', 'β': 'v', 'γ': 'g', 'δ': 'd', 'ε': 'e',
        'ζ': 'z', 'η': 'i', 'θ': 'th', 'ι': 'i', 'κ': 'k',
        'λ': 'l', 'μ': 'm', 'ν': 'n', 'ξ': 'x', 'ο': 'o',
        'π': 'p', 'ρ': 'r', 'σ': 's', 'τ': 't', 'υ': 'y',  # 'y' is a common transliteration for upsilon
        'φ': 'f', 'χ': 'ch', 'ψ': 'ps', 'ω': 'o',
        'ς': 's',  # final sigma
    }
    cyrillic_to_latin_map = {
        'а': 'a', 'б': 'b', 'в': 'v', 'г': 'g', 'д': 'd', 'е': 'e', 'ё': 'yo', 'ж': 'zh',
        'з': 'z', 'и': 'i', 'й': 'y', 'к': 'k', 'л': 'l', 'м': 'm', 'н': 'n', 'о': 'o',
        'п': 'p', 'р': 'r', 'с': 's', 'т': 't', 'у': 'u', 'ф': 'f', 'х': 'kh', 'ц': 'ts',
        'ч': 'ch', 'ш': 'sh', 'щ': 'shch', 'ъ': '', 'ы': 'y', 'ь': '', 'э': 'e', 'ю': 'yu',
        'я': 'ya',
    }
    # Direct Cyrillic to Greek mapping based on phonetic similarity.
    # These are approximations and may not be universally accepted transliterations.
    cyrillic_to_greek_map = {
        'а': 'α', 'б': 'β', 'в': 'β', 'г': 'γ', 'д': 'δ', 'е': 'ε', 'ё': 'ιο', 'ж': 'ζ',
        'з': 'ζ', 'и': 'ι', 'й': 'ι', 'к': 'κ', 'л': 'λ', 'м': 'μ', 'н': 'ν', 'о': 'ο',
        'п': 'π', 'р': 'ρ', 'с': 'σ', 'т': 'τ', 'у': 'ου', 'ф': 'φ', 'х': 'χ', 'ц': 'τσ',
        'ч': 'τσ',  # or τζ depending on the desired sound
        'ш': 'σ', 'щ': 'σ',  # approximations
        'ъ': '', 'ы': 'ι', 'ь': '', 'э': 'ε', 'ю': 'ιου',
        'я': 'ια',
    }
    # Lowercase the input text, preserving accents on Latin characters.
    # casefold() would give more aggressive caseless matching across Unicode; lower() suffices here.
    lowercased_text = text.lower()
    output_chars = []
    current_index = 0
    if lang == 'grc':
        # Combine all relevant maps for direct lookup to Greek
        conversion_map = {**latin_to_greek_map, **cyrillic_to_greek_map}
        # Sort keys by length in reverse order to handle multi-character sequences first
        sorted_source_keys = sorted(
            list(latin_to_greek_map.keys()) + list(cyrillic_to_greek_map.keys()),
            key=len,
            reverse=True
        )
        while current_index < len(lowercased_text):
            found_conversion = False
            for key in sorted_source_keys:
                if lowercased_text.startswith(key, current_index):
                    output_chars.append(conversion_map[key])
                    current_index += len(key)
                    found_conversion = True
                    break
            if not found_conversion:
                # If no specific mapping is found, append the character as is.
                # This handles unmapped characters and already Greek characters.
                output_chars.append(lowercased_text[current_index])
                current_index += 1
        return ''.join(output_chars)
    else:  # Default to 'lat' conversion
        # Combine Greek to Latin and Cyrillic to Latin maps.
        # Cyrillic keys take precedence in case of overlap since they are merged last.
        combined_to_latin_map = {**greek_to_latin_map, **cyrillic_to_latin_map}
        # Sort all relevant source keys by length in reverse for replacement
        sorted_source_keys = sorted(
            list(greek_to_latin_map.keys()) + list(cyrillic_to_latin_map.keys()),
            key=len,
            reverse=True
        )
        while current_index < len(lowercased_text):
            found_conversion = False
            for key in sorted_source_keys:
                if lowercased_text.startswith(key, current_index):
                    latin_equivalent = combined_to_latin_map[key]
                    # Strip accents ONLY if the source character came from the Greek map.
                    # This preserves accents on original Latin characters (like 'é')
                    # and allows intentional accent stripping from Greek transliterations.
                    if key in greek_to_latin_map:
                        normalized_latin = unicodedata.normalize('NFD', latin_equivalent)
                        stripped_latin = ''.join(c for c in normalized_latin if not unicodedata.combining(c))
                        output_chars.append(stripped_latin)
                    else:
                        output_chars.append(latin_equivalent)
                    current_index += len(key)
                    found_conversion = True
                    break
            if not found_conversion:
                # If no conversion happened from Greek or Cyrillic, append the character as is.
                # This preserves existing Latin characters (including accented ones from the input),
                # numbers, punctuation, and other symbols.
                output_chars.append(lowercased_text[current_index])
                current_index += 1
        return ''.join(output_chars)
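
# Illustrative examples (unaccented input, per the maps above):
#   only_greek_or_only_latin('kalimera', lang='grc')  -> 'καλιμερα'
#   only_greek_or_only_latin('καλημερα', lang='lat')  -> 'kalimera'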


def other_tts(text='Hallov worlds Far over the',
              ref_s='wav/af_ZA_google-nwu_0184.wav'):
    text = only_greek_or_only_latin(text, lang='eng')
    x = _tts.inference(text, ref_s=ref_s)[0, 0, :].cpu().numpy()
    # x /= np.abs(x).max() + 1e-7  ~ volume normalisation happens in api.py:tts_multi_sentence() / demo.py
    tmp_file = '_speech.wav'  # single temp file shared by all clients (no per-client cleanup)
    audiofile.write(tmp_file, x, 24000)
    return tmp_file


def update_selected_voice(voice_filename):
    return 'wav/' + voice_filename + '.wav'
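
# Each voice button writes its wav path into the `selected_voice` gr.State below;
# the "Generate Audio" button then passes that state to other_tts() as `ref_s`.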

description = (
    "Estimate **age**, **gender**, and **expression** "
    "of the speaker contained in an audio file or microphone recording. \n"
    f"The model [{age_gender_model_name}]"
    f"(https://huggingface.co/{age_gender_model_name}) "
    "recognises age and gender, "
    f"whereas [{expression_model_name}]"
    f"(https://huggingface.co/{expression_model_name}) "
    "recognises the expression dimensions arousal, dominance, and valence."
)

with gr.Blocks() as demo:
    with gr.Tab(label="other TTS"):
        selected_voice = gr.State(value='wav/en_US_m-ailabs_mary_ann.wav')
        with gr.Row():
            voice_info = gr.Markdown(f'TTS vox : `{selected_voice.value}`')
        # Main input and output components
        with gr.Row():
            text_input = gr.Textbox(
                label="Enter text for TTS:",
                placeholder="Type your message here...",
                lines=4,
                value="Far over the misty mountains cold, to dungeons deep and caverns old.",
            )
            generate_button = gr.Button("Generate Audio", variant="primary")
            output_audio = gr.Audio(label="TTS Output")
        with gr.Column():
            voice_buttons = []
            for i in range(0, len(VOICES), 7):
                with gr.Row():
                    for voice_filename in VOICES[i:i+7]:
                        voice_filename = voice_filename[4:-4]  # drop 'wav/' prefix and '.wav' suffix for display
                        button = gr.Button(voice_filename)
                        button.click(
                            fn=update_selected_voice,
                            inputs=[gr.Textbox(value=voice_filename, visible=False)],
                            outputs=[selected_voice]
                        )
                        button.click(
                            fn=lambda v=voice_filename: f"TTS Vox = `{v}`",
                            inputs=None,
                            outputs=voice_info
                        )
                        voice_buttons.append(button)
        generate_button.click(
            fn=other_tts,
            inputs=[text_input, selected_voice],
            outputs=output_audio
        )
    with gr.Tab(label="Speech Analysis"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
                input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                    min_length=0.025,  # seconds
                )
                gr.Examples(
                    [
                        "wav/female-46-neutral.wav",
                        "wav/female-20-happy.wav",
                        "wav/male-60-angry.wav",
                        "wav/male-27-sad.wav",
                    ],
                    [input],
                    label="Examples from CREMA-D, ODbL v1.0 license",
                )
                gr.Markdown("Only the first two seconds of the audio will be processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Image(label="Expression")
        outputs = [output_age, output_gender, output_expression]
        submit_btn.click(recognize, input, outputs)

demo.launch(debug=True)