|
|
|
|
|
import typing |
|
|
import types |
|
|
import gradio as gr |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import Wav2Vec2Processor |
|
|
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model |
|
|
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel |
|
|
import audiofile |
|
|
from tts import StyleTTS2 |
|
|
import audresample |
|
|
import json |
|
|
import re |
|
|
import unicodedata |
|
|
import textwrap |
|
|
import nltk |
|
|
from num2words import num2words |
|
|
from num2word_greek.numbers2words import convert_numbers |
|
|
from audionar import VitsModel, VitsTokenizer |
|
|
|
|
|
# Fetch the NLTK tokenizer resources into the working directory and make
# NLTK look for its data there.
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource, download_dir='./')
nltk.data.path.append('.')

# Initial compute device; reassigned further down once CUDA has been probed.
device = 'cpu'
|
|
|
|
|
|
|
|
# Per-language substitution tables used by fix_vocals().  Each table mixes
# phonetic respellings (first entries) with spoken forms of math / currency
# symbols.  Insertion order matters: keys of equal length are applied in the
# order they are listed here.
_RON_REPLACEMENTS = {
    'ţ': 'ț',
    'ț': 'ts',
    'î': 'u',
    'â': 'a',
    'ş': 's',
    'w': 'oui',
    'k': 'c',
    'l': 'll',
    'sqrt': ' rădăcina pătrată din ',
    '^': ' la puterea ',
    '+': ' plus ',
    ' - ': ' minus ',
    '*': ' ori ',
    '/': ' împărțit la ',
    '=': ' egal cu ',
    'pi': ' pi ',
    '<': ' mai mic decât ',
    '>': ' mai mare decât ',  # trailing space added so the next word is not glued on
    '%': ' la sută ',
    '(': ' paranteză deschisă ',
    ')': ' paranteză închisă ',
    '[': ' paranteză pătrată deschisă ',
    ']': ' paranteză pătrată închisă ',
    '{': ' acoladă deschisă ',
    '}': ' acoladă închisă ',
    '≠': ' nu este egal cu ',
    '≤': ' mai mic sau egal cu ',
    '≥': ' mai mare sau egal cu ',
    '≈': ' aproximativ ',
    '∞': ' infinit ',
    '€': ' euro ',
    '$': ' dolar ',
    '£': ' liră ',
    '&': ' și ',
    '@': ' la ',
    '#': ' diez ',
    '∑': ' sumă ',
    '∫': ' integrală ',
    '√': ' rădăcina pătrată a ',
}

_ENG_REPLACEMENTS = {
    'wik': 'weaky',
    'sh': 'ss',
    'ch': 'ttss',
    'oo': 'oeo',
    'sqrt': ' square root of ',
    '^': ' to the power of ',
    '+': ' plus ',
    ' - ': ' minus ',
    '*': ' times ',
    ' / ': ' divided by ',
    '=': ' equals ',
    'pi': ' pi ',
    '<': ' less than ',
    '>': ' greater than ',
    '%': ' percent ',
    '(': ' open parenthesis ',
    ')': ' close parenthesis ',
    '[': ' open bracket ',
    ']': ' close bracket ',
    '{': ' open curly brace ',
    '}': ' close curly brace ',
    '∑': ' sum ',
    '∫': ' integral ',
    '√': ' square root of ',
    '≠': ' not equals ',
    '≤': ' less than or equals ',
    '≥': ' greater than or equals ',
    '≈': ' approximately ',
    '∞': ' infinity ',
    '€': ' euro ',
    '$': ' dollar ',
    '£': ' pound ',
    '&': ' and ',
    '@': ' at ',
    '#': ' hash ',
}

_SERBIAN_REPLACEMENTS = {
    'rn': 'rrn',
    'ć': 'č',
    'c': 'č',
    'đ': 'd',
    'j': 'i',
    'l': 'lll',
    'w': 'v',
    'sqrt': ' kvadratni koren iz ',  # padded like every other language's entry
    '^': ' na stepen ',
    '+': ' plus ',
    ' - ': ' minus ',
    '*': ' puta ',
    ' / ': ' podeljeno sa ',
    '=': ' jednako ',
    'pi': ' pi ',
    '<': ' manje od ',
    '>': ' veće od ',
    '%': ' procenat ',
    '(': ' otvorena zagrada ',
    ')': ' zatvorena zagrada ',
    '[': ' otvorena uglasta zagrada ',
    ']': ' zatvorena uglasta zagrada ',
    '{': ' otvorena vitičasta zagrada ',
    '}': ' zatvorena vitičasta zagrada ',
    '∑': ' suma ',
    '∫': ' integral ',
    '√': ' kvadratni koren ',
    '≠': ' nije jednako ',
    '≤': ' manje ili jednako od ',
    '≥': ' veće ili jednako od ',
    '≈': ' približno ',
    '∞': ' beskonačnost ',
    '€': ' evro ',
    '$': ' dolar ',
    '£': ' funta ',
    '&': ' i ',
    '@': ' et ',
    '#': ' taraba ',
}

_DEU_REPLACEMENTS = {
    'sch': 'sh',
    'ch': 'kh',
    'ie': 'ee',
    'ei': 'ai',
    'ä': 'ae',
    'ö': 'oe',
    'ü': 'ue',
    'ß': 'ss',
    'sqrt': ' Quadratwurzel aus ',
    '^': ' hoch ',
    '+': ' plus ',
    ' - ': ' minus ',
    '*': ' mal ',
    ' / ': ' geteilt durch ',
    '=': ' gleich ',
    'pi': ' pi ',
    '<': ' kleiner als ',
    '>': ' größer als ',  # trailing space added so the next word is not glued on
    '%': ' prozent ',
    '(': ' Klammer auf ',
    ')': ' Klammer zu ',
    '[': ' eckige Klammer auf ',
    ']': ' eckige Klammer zu ',
    '{': ' geschweifte Klammer auf ',
    '}': ' geschweifte Klammer zu ',
    '∑': ' Summe ',
    '∫': ' Integral ',
    '√': ' Quadratwurzel ',
    '≠': ' ungleich ',
    '≤': ' kleiner oder gleich ',
    '≥': ' größer oder gleich ',
    '≈': ' ungefähr ',
    '∞': ' unendlich ',
    '€': ' euro ',
    '$': ' dollar ',
    '£': ' pfund ',
    '&': ' und ',
    '@': ' at ',
    '#': ' raute ',
}

_FRA_REPLACEMENTS = {
    'w': 'v',
    'sqrt': ' racine carrée de ',
    '^': ' à la puissance ',
    '+': ' plus ',
    ' - ': ' moins ',
    '*': ' fois ',
    ' / ': ' divisé par ',
    '=': ' égale ',
    'pi': ' pi ',
    '<': ' inférieur à ',
    '>': ' supérieur à ',
    '%': ' pour cent ',
    '(': ' parenthèse ouverte ',
    ')': ' parenthèse fermée ',
    '[': ' crochet ouvert ',
    ']': ' crochet fermé ',
    '{': ' accolade ouverte ',
    '}': ' accolade fermée ',
    '∑': ' somme ',
    '∫': ' intégrale ',
    '√': ' racine carrée ',
    '≠': " n'égale pas ",
    '≤': ' inférieur ou égal à ',
    '≥': ' supérieur ou égal à ',
    '≈': ' approximativement ',
    '∞': ' infini ',
    '€': ' euro ',
    '$': ' dollar ',
    '£': ' livre ',
    '&': ' et ',
    '@': ' arobase ',
    '#': ' dièse ',
}

_HUN_REPLACEMENTS = {
    'ch': 'ts',
    'cs': 'tz',
    'g': 'gk',
    'w': 'v',
    'z': 'zz',
    'sqrt': ' négyzetgyök ',
    '^': ' hatvány ',
    '+': ' plusz ',
    ' - ': ' mínusz ',
    '*': ' szorozva ',
    ' / ': ' osztva ',
    '=': ' egyenlő ',
    'pi': ' pi ',
    '<': ' kisebb mint ',
    '>': ' nagyobb mint ',
    '%': ' százalék ',
    '(': ' nyitó zárójel ',
    ')': ' záró zárójel ',
    '[': ' nyitó szögletes zárójel ',
    ']': ' záró szögletes zárójel ',
    '{': ' nyitó kapcsos zárójel ',
    '}': ' záró kapcsos zárójel ',
    '∑': ' szumma ',
    '∫': ' integrál ',
    '√': ' négyzetgyök ',
    '≠': ' nem egyenlő ',
    '≤': ' kisebb vagy egyenlő ',
    '≥': ' nagyobb vagy egyenlő ',
    '≈': ' körülbelül ',
    '∞': ' végtelen ',
    '€': ' euró ',
    '$': ' dollár ',
    '£': ' font ',
    '&': ' és ',
    '@': ' kukac ',
    '#': ' kettőskereszt ',
}

_GRC_REPLACEMENTS = {
    'sqrt': ' τετραγωνικὴ ῥίζα ',
    '^': ' εἰς τὴν δύναμιν ',
    '+': ' σὺν ',
    ' - ': ' χωρὶς ',
    '*': ' πολλάκις ',
    ' / ': ' διαιρέω ',
    '=': ' ἴσον ',
    'pi': ' πῖ ',
    '<': ' ἔλαττον ',
    '>': ' μεῖζον ',
    '%': ' τοῖς ἑκατόν ',
    '(': ' ἀνοικτὴ παρένθεσις ',
    ')': ' κλειστὴ παρένθεσις ',
    '[': ' ἀνοικτὴ ἀγκύλη ',
    ']': ' κλειστὴ ἀγκύλη ',
    '{': ' ἀνοικτὴ σγουρὴ ἀγκύλη ',
    '}': ' κλειστὴ σγουρὴ ἀγκύλη ',
    '∑': ' ἄθροισμα ',
    '∫': ' ὁλοκλήρωμα ',
    '√': ' τετραγωνικὴ ῥίζα ',
    '≠': ' οὐκ ἴσον ',
    '≤': ' ἔλαττον ἢ ἴσον ',
    '≥': ' μεῖζον ἢ ἴσον ',
    '≈': ' περίπου ',
    '∞': ' ἄπειρον ',
    '€': ' εὐρώ ',
    '$': ' δολάριον ',
    '£': ' λίρα ',
    '&': ' καὶ ',
    '@': ' ἀτ ',
    '#': ' δίεση ',
}

# Language code -> list of (old, new) pairs, longest key first.  sorted() is
# stable, so keys of equal length keep their table order.  Computed once at
# import time instead of on every call.
_ORDERED_REPLACEMENTS = {
    _code: sorted(_table.items(), key=lambda item: len(item[0]), reverse=True)
    for _code, _table in (
        ('grc', _GRC_REPLACEMENTS),
        ('ron', _RON_REPLACEMENTS),
        ('eng', _ENG_REPLACEMENTS),
        ('deu', _DEU_REPLACEMENTS),
        ('fra', _FRA_REPLACEMENTS),
        ('hun', _HUN_REPLACEMENTS),
        ('rmc-script_latin', _SERBIAN_REPLACEMENTS),
    )
}


def fix_vocals(text, lang='ron'):
    """Apply the language-specific substitution table to ``text``.

    The table for ``lang`` respells some letter groups phonetically and
    replaces math / currency symbols with spoken words.  Substitutions are
    plain ``str.replace`` calls applied one after another, longest key first,
    so the output of an earlier substitution can be rewritten by a later one.

    Args:
        text: Text to rewrite.
        lang: One of 'grc', 'ron', 'eng', 'deu', 'fra', 'hun',
            'rmc-script_latin'.

    Returns:
        The rewritten text.  For any other ``lang`` a warning is printed and
        ``text`` is returned unchanged.
    """
    ordered = _ORDERED_REPLACEMENTS.get(lang)
    if not ordered:
        print(f"Warning: Language '{lang}' not supported for text replacement. Returning original text.")
        return text

    for old, new in ordered:
        text = text.replace(old, new)
    return text
|
|
|
|
|
|
|
|
def _num2words(text='01234', lang=None):
    """Spell out the number ``text`` as words in language ``lang``.

    'grc' is routed to ``convert_numbers`` (num2word_greek); every other
    language code is handed to ``num2words``.
    """
    converter = convert_numbers if lang == 'grc' else None
    if converter is not None:
        return converter(text)
    return num2words(text, lang=lang)
|
|
|
|
|
|
|
|
def transliterate_number(number_string,
                         lang=None):
    """Replace every number in ``number_string`` with its spoken form.

    Recognised numbers are runs of digits with an optional decimal part
    (``3.14``) and an optional exponent (``1e-5``).  Integers are spelled via
    ``_num2words``; decimals are read as "<integer> <decimal word> <digit>
    <digit> ..."; exponent notation is read as "<base> <exponent phrase>
    <exponent>".

    Args:
        number_string: Text that may contain numbers.
        lang: Language code as used elsewhere in this file ('ron', 'hun',
            'deu', 'fra', 'grc', 'rmc-script_latin', ...).  Any other value is
            cut to its first two letters and English phrases are used; ``None``
            falls back to English.

    Returns:
        ``number_string`` with each number spelled out.  A number that cannot
        be converted is left untouched.
    """
    # lang -> (code passed to _num2words, exponent phrase, decimal word)
    phrases = {
        'rmc-script_latin': ('sr', ' puta deset na stepen od ', ' tačka '),
        # The exponent phrase used to be a copy of the Hungarian one.
        'ron': ('ro', ' ori zece la puterea ', ' virgulă '),
        # 'egész' is the Hungarian word read at the decimal mark.
        'hun': ('hu', ' tízszer a erejéig ', ' egész '),
        'deu': ('de', ' mal zehn hoch ', ' komma '),
        # Exponent phrase now includes "times ten"; the decimal word is
        # padded so neighbouring words are not glued together.
        'fra': ('fr', ' fois dix puissance ', ' virgule '),
        'grc': ('grc', ' εις την δυναμην του ', ' κομμα '),
    }
    if lang in phrases:
        lang, exponential_pronoun, comma = phrases[lang]
    else:
        # ``lang`` defaults to None, which cannot be sliced.
        lang = (lang or 'en')[:2]
        exponential_pronoun = ' times ten to the power of '
        comma = ' point '

    def replace_number(match):
        number_part = match.group(0)
        try:
            lowered = number_part.lower()
            if 'e' in lowered:
                base, exponent = lowered.split('e')
                return (_num2words(base, lang=lang)
                        + exponential_pronoun
                        + _num2words(exponent, lang=lang))
            if '.' in number_part:
                integer_part, decimal_part = number_part.split('.')
                spoken_digits = " ".join(
                    _num2words(digit, lang=lang) for digit in decimal_part)
                return _num2words(integer_part, lang=lang) + comma + spoken_digits
            return _num2words(number_part, lang=lang)
        except (ValueError, NotImplementedError, ArithmeticError):
            # num2words raises NotImplementedError for unknown languages and
            # OverflowError / decimal.InvalidOperation (both ArithmeticError)
            # for values it cannot spell; keep the digits in that case.
            return number_part

    # Only the number itself is matched; the text around it is left as is.
    pattern = r'\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'
    return re.sub(pattern, replace_number, number_string)
|
|
|
|
|
|
|
|
# Display names offered in the "TTS language" dropdown.
language_names = [
    'Ancient greek',
    'English',
    'Deutsch',
    'French',
    'Hungarian',
    'Romanian',
    'Serbian (Approx.)',
]
|
|
|
|
|
|
|
|
def audionar_tts(text=None,
                 lang='romanian'):
    """Synthesise ``text`` with a ``facebook/mms-tts-<code>`` model.

    The language name is mapped to a model code by substring matching, the
    text is converted to a single script, numbers are spelled out, the
    language-specific substitutions are applied, and the result is synthesised
    in chunks of at most 439 characters.  The model and tokenizer of the most
    recently used language are kept in module globals so repeated requests for
    the same language do not reload them.

    NOTE: ``only_greek_or_only_latin`` runs before ``fix_vocals``, so for
    'grc' the Latin-letter keys of the substitution table ('sqrt', 'pi') can
    no longer match by the time the table is applied.

    Args:
        text: Text to speak (a string; a list of ready-made chunks skips the
            wrapping step).
        lang: Free-form language name, e.g. an entry of ``language_names``.

    Returns:
        Path of the written wav file ('_speech.wav', 16 kHz).

    Raises:
        gr.Error: If no text or no language was supplied.
    """
    global cached_lang_code, cached_net_g, cached_tokenizer

    # Fail with a readable UI error instead of AttributeError / IndexError /
    # an empty torch.cat further down.
    if text is None or (isinstance(text, str) and not text.strip()):
        raise gr.Error("No text submitted! Please type some text before submitting your request.")
    if not lang or not lang.strip():
        raise gr.Error("No language selected! Please choose a TTS language.")

    lang = lang.lower()

    # (substrings looked for in the language name, model code); first hit wins.
    language_hints = (
        (('hun',), 'hun'),
        (('ser', 'bosn', 'herzegov', 'montenegr', 'macedon'), 'rmc-script_latin'),
        (('rom',), 'ron'),
        (('ger', 'deu', 'allem'), 'deu'),
        (('french',), 'fra'),
        (('eng',), 'eng'),
        (('ancient greek',), 'grc'),
    )
    for hints, code in language_hints:
        if any(hint in lang for hint in hints):
            lang_code = code
            break
    else:
        # Unknown name: use its first word as the model code.
        lang_code = lang.split()[0].strip()

    text = only_greek_or_only_latin(text, lang=lang_code)
    text = transliterate_number(text, lang=lang_code)
    text = fix_vocals(text, lang=lang_code)

    # (Re)load the model only when the language changed since the last call.
    if globals().get('cached_lang_code') != lang_code:
        cached_lang_code = lang_code
        cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
        cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')

    net_g = cached_net_g
    tokenizer = cached_tokenizer

    if not isinstance(text, list):
        text = textwrap.wrap(text, width=439)
    if not text:
        # Everything was stripped by the preprocessing steps above.
        raise gr.Error("No text submitted! Please type some text before submitting your request.")

    total_audio = []
    for _t in text:
        inputs = tokenizer(_t, return_tensors="pt")
        with torch.no_grad():
            x = net_g(input_ids=inputs.input_ids.to(device),
                      attention_mask=inputs.attention_mask.to(device),
                      lang_code=lang_code,
                      )[0, :]
        total_audio.append(x)
        print(f'\n\n_______________________________ {_t} {x.shape=}')

    x = torch.cat(total_audio).cpu().numpy()

    tmp_file = '_speech.wav'
    audiofile.write(tmp_file, x, 16000)
    return tmp_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use CUDA device 0 when available, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = 0

# Seconds of audio read from the input file by recognize().
duration = 2

# Hugging Face model identifiers of the two speech-analysis models.
age_gender_model_name = "audeering/wav2vec2-large-robust-6-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
|
|
|
|
|
|
|
class AgeGenderHead(nn.Module):
    r"""Head used for the age and gender outputs.

    dropout -> linear -> tanh -> dropout -> linear projection to
    ``num_labels`` values.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to ``num_labels`` outputs; extra kwargs are ignored."""
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(torch.tanh(hidden))
        return self.out_proj(hidden)
|
|
|
|
|
|
|
|
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Wav2Vec2 backbone with one age head and one 3-way gender head."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        self.init_weights()

    def forward(
        self,
        frozen_cnn7,
    ):
        """Predict age and gender from ``frozen_cnn7`` features.

        ``self.wav2vec2`` is called with ``frozen_cnn7`` as keyword; its
        output is averaged over dim 1 before being fed to both heads.

        Returns:
            Tuple ``(pooled hidden states, age output, gender softmax)``.
        """
        pooled = torch.mean(self.wav2vec2(frozen_cnn7=frozen_cnn7), dim=1)
        age_output = self.age(pooled)
        gender_probs = torch.softmax(self.gender(pooled), dim=1)
        return pooled, age_output, gender_probs
|
|
|
|
|
|
|
|
|
|
|
def _forward(
        self,
        frozen_cnn7=None,
        attention_mask=None):
    """Run feature projection and the encoder on precomputed ``frozen_cnn7``.

    Meant to be bound as a replacement ``forward``; ``self`` must provide
    ``wav2vec2.feature_projection``, ``wav2vec2.encoder`` and, when a mask is
    given, ``_get_feature_vector_attention_mask``.

    Returns:
        The first element of the encoder output.
    """
    mask = attention_mask
    if mask is not None:
        # Reduce the mask to the length of the feature sequence (dim 1).
        mask = self._get_feature_vector_attention_mask(
            frozen_cnn7.shape[1], mask, add_adapter=False
        )

    projected, _ = self.wav2vec2.feature_projection(frozen_cnn7)
    encoder_outputs = self.wav2vec2.encoder(
        projected,
        attention_mask=mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )
    return encoder_outputs[0]
|
|
|
|
|
|
|
|
def _forward_and_cnn7(
        self,
        input_values,
        attention_mask=None):
    """Run the full wav2vec2 stack and also return the extractor features.

    Meant to be bound as a replacement ``forward``.  The output of
    ``wav2vec2.feature_extractor`` is transposed on dims 1 and 2, projected
    and encoded.

    Returns:
        Tuple ``(first element of the encoder output, transposed features)``.
    """
    backbone = self.wav2vec2
    frozen_cnn7 = backbone.feature_extractor(input_values).transpose(1, 2)

    mask = attention_mask
    if mask is not None:
        # Reduce the mask to the length of the feature sequence (dim 1).
        mask = backbone._get_feature_vector_attention_mask(
            frozen_cnn7.shape[1], mask, add_adapter=False
        )

    projected, _ = backbone.feature_projection(frozen_cnn7)
    encoder_outputs = backbone.encoder(
        projected,
        attention_mask=mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )
    return encoder_outputs[0], frozen_cnn7
|
|
|
|
|
|
|
|
class ExpressionHead(nn.Module):
    r"""Head used for the expression outputs.

    dropout -> linear -> tanh -> dropout -> linear projection to
    ``config.num_labels`` values.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to ``config.num_labels`` outputs; extra kwargs are ignored."""
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(torch.tanh(hidden))
        return self.out_proj(hidden)
|
|
|
|
|
|
|
|
class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Wav2Vec2 backbone with an expression head."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        """Predict expression values from ``input_values``.

        Expects ``self.wav2vec2`` to return a pair
        ``(hidden states, frozen_cnn7)``; the hidden states are averaged over
        dim 1 before classification.

        Returns:
            Tuple ``(pooled hidden states, classifier output, frozen_cnn7)``.
        """
        sequence_output, frozen_cnn7 = self.wav2vec2(input_values)
        pooled = torch.mean(sequence_output, dim=1)
        return pooled, self.classifier(pooled), frozen_cnn7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the analysis models and move them to `device`.  process_func() sends
# its input tensor to `device`, so the weights have to live there as well;
# otherwise a CUDA host fails with a device mismatch.
age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name).to(device)
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = ExpressionModel.from_pretrained(expression_model_name).to(device)

# Replace the backbones' forward methods: the expression model additionally
# returns the feature-extractor output, which the age/gender model then
# consumes instead of recomputing it from the waveform.
age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
|
|
|
|
|
def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
    """Predict age, gender and expression for the signal ``x``.

    The signal is normalised by ``expression_processor``, passed through the
    expression model, and the feature-extractor output of that pass is reused
    as input of the age/gender model.  The expression values are rendered
    with ``plot_expression`` and saved to 'expression.png'.

    Args:
        x: Audio signal.
        sampling_rate: Sampling rate of ``x`` in Hz.

    Returns:
        Tuple ``(age text, {"female"/"male"/"child": score}, image path)``.
    """
    y = expression_processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    with torch.no_grad():
        _, logits_expression, frozen_cnn7 = expression_model(y)
        _, logits_age, logits_gender = age_gender_model(frozen_cnn7=frozen_cnn7)

    plot_expression(logits_expression[0, 0].item(),
                    logits_expression[0, 1].item(),
                    logits_expression[0, 2].item())
    expression_file = "expression.png"
    plt.savefig(expression_file)
    # plot_expression() opens a new figure on every call; close it so figures
    # do not pile up in memory across requests.
    plt.close()

    return (
        f"{round(100 * logits_age[0, 0].item())} years",
        {
            "female": logits_gender[0, 0].item(),
            "male": logits_gender[0, 1].item(),
            "child": logits_gender[0, 2].item(),
        },
        expression_file,
    )
|
|
|
|
|
|
|
|
def recognize(input_file):
    """Gradio callback: analyse the audio file at ``input_file``.

    Reads at most ``duration`` seconds, resamples to 16 kHz and returns
    whatever ``process_func`` returns.

    Raises:
        gr.Error: If no file was submitted.
    """
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! Please upload or record an audio file before submitting your request."
        )

    target_rate = 16000
    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    resampled = audresample.resample(signal, sampling_rate, target_rate)
    return process_func(resampled, target_rate)
|
|
|
|
|
|
|
|
def explode(data):
    """Spread the entries of ``data`` apart by inserting zeros between them.

    Every axis of length ``n`` grows to ``2 * n - 1``; the original values
    end up at the even indices and the new odd positions are zero (``False``
    for boolean input).  Used to create the gaps between voxels in the 3D
    plot.  Works for arrays of any number of dimensions.

    This function used to be defined twice in a row with identical behaviour;
    the two copies are merged here.

    Args:
        data: NumPy array.

    Returns:
        New array of the same dtype with shape ``2 * shape - 1``.
    """
    new_shape = tuple(2 * size - 1 for size in data.shape)
    exploded = np.zeros(new_shape, dtype=data.dtype)
    every_other = (slice(None, None, 2),) * data.ndim
    exploded[every_other] = data
    return exploded
|
|
|
|
|
def plot_expression(arousal, dominance, valence):
    """Draw the expression values as one highlighted voxel in a 5x5x5 cube.

    A new matplotlib figure with a 3D axis is created and left as the current
    figure; nothing is returned.  Arousal is plotted on x, dominance on y
    (entered as ``.994 - dominance``, tick labels inverted accordingly) and
    valence on z.
    """
    N_PIX = 5

    # Near-zero random alpha everywhere, a visible alpha in the selected cell.
    alpha_grid = np.random.rand(N_PIX, N_PIX, N_PIX) * 1e-3
    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
    i_arousal, i_dominance, i_valence = (adv * N_PIX).astype(np.int64)
    alpha_grid[i_arousal, i_dominance, i_valence] = .22

    # Voxel occupancy with gaps between neighbouring voxels.
    occupancy = explode(np.ones((N_PIX, N_PIX, N_PIX), dtype=bool))

    # Corner coordinates: shrink the gap voxels so the real ones fill the cube.
    x, y, z = np.indices(np.array(occupancy.shape) + 1).astype(float) // 2
    x[1::2, :, :] += 1
    y[:, 1::2, :] += 1
    z[:, :, 1::2] += 1

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # RGBA face colours: colour comes from the 'cool' colormap evaluated on
    # the alpha values, alpha itself is clipped afterwards.
    side = 2 * N_PIX - 1
    rgba = np.ones([side, side, side, 4], dtype=np.float64)
    rgba[..., 3] = explode(alpha_grid)
    cmap = plt.get_cmap('cool')
    rgba[..., :3] = cmap(rgba[..., 3])[..., :3]
    rgba[..., 3] = rgba[..., 3].clip(.01, .74)

    ax.voxels(x, y, z, occupancy, facecolors=rgba, edgecolors=.006 * rgba)
    ax.set_aspect('equal')

    cube_ticks = [0, N_PIX]
    ax.set_zticks(cube_ticks)
    ax.set_xticks(cube_ticks)
    ax.set_yticks(cube_ticks)

    ax.set_zticklabels([f'{tick/N_PIX:.2f}' for tick in ax.get_zticks()])
    ax.set_zlabel('valence', fontsize=10, labelpad=0)
    ax.set_xticklabels([f'{tick/N_PIX:.2f}' for tick in ax.get_xticks()])
    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
    ax.set_yticklabels([f'{1-tick/N_PIX:.2f}' for tick in ax.get_yticks()], rotation=90)
    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
    ax.grid(False)

    # Outline of the cube's top face.
    top = [N_PIX, N_PIX]
    ax.plot([N_PIX, N_PIX], [0, N_PIX + .2], top, 'g', linewidth=1)
    ax.plot([0, N_PIX], [N_PIX, N_PIX + .24], top, 'k', linewidth=1)
    ax.plot([0, 0], [0, N_PIX], top, 'darkred', linewidth=1)
    ax.plot([0, N_PIX], [0, 0], top, 'darkblue', linewidth=1)

    ax.xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))

    ax.set_xlim(0, N_PIX)
    ax.set_ylim(0, N_PIX)
    ax.set_zlim(0, N_PIX)
|
|
|
|
|
|
|
|
|
|
|
# Reference voices offered in the "other TTS" tab.  Only *.wav files are
# listed, because the UI strips a 4-character extension and
# update_selected_voice() re-appends '.wav'; any other file would yield a
# path that does not exist.  Sorted so the button order is stable between
# runs (os.listdir order is arbitrary).
VOICES = [f'wav/{vox}' for vox in sorted(os.listdir('wav')) if vox.endswith('.wav')]

# StyleTTS2 model used by other_tts(); kept on the CPU.
_tts = StyleTTS2().to('cpu')
|
|
|
|
|
def only_greek_or_only_latin(text, lang='grc'):
    """Transliterate ``text`` so that it uses a single script.

    The text is lowercased first.  With ``lang='grc'`` Latin and Cyrillic
    letters are rewritten in Greek; with any other ``lang`` Greek and
    Cyrillic letters are rewritten in Latin.  At each position the longest
    matching key is used (e.g. 'ch', 'sc', 'ου' win over single letters).

    When converting to Latin, an accented Greek or Cyrillic letter that has
    no mapping of its own is first reduced to its base letter (e.g. 'ά' is
    treated as 'α').  Characters without any mapping are preserved as is, so
    accented Latin letters (e.g. 'É', 'ü') are kept in their lowercase form.

    Args:
        text: Input text.
        lang: 'grc' for Greek output, anything else for Latin output.

    Returns:
        The lowercased, transliterated text.
    """
    latin_to_greek_map = {
        'a': 'α', 'b': 'β', 'g': 'γ', 'd': 'δ', 'e': 'ε',
        'ch': 'τσο',
        'z': 'ζ', 'h': 'χ', 'i': 'ι', 'k': 'κ', 'l': 'λ',
        'm': 'μ', 'n': 'ν', 'x': 'ξ', 'o': 'ο', 'p': 'π',
        'v': 'β', 'sc': 'σκ', 'r': 'ρ', 's': 'σ', 't': 'τ',
        'u': 'ου', 'f': 'φ', 'c': 'σ', 'w': 'β', 'y': 'γ',
    }

    greek_to_latin_map = {
        'ου': 'ou',
        'α': 'a', 'β': 'v', 'γ': 'g', 'δ': 'd', 'ε': 'e',
        'ζ': 'z', 'η': 'i', 'θ': 'th', 'ι': 'i', 'κ': 'k',
        'λ': 'l', 'μ': 'm', 'ν': 'n', 'ξ': 'x', 'ο': 'o',
        'π': 'p', 'ρ': 'r', 'σ': 's', 'τ': 't', 'υ': 'y',
        'φ': 'f', 'χ': 'ch', 'ψ': 'ps', 'ω': 'o',
        'ς': 's',
    }

    cyrillic_to_latin_map = {
        'а': 'a', 'б': 'b', 'в': 'v', 'г': 'g', 'д': 'd', 'е': 'e', 'ё': 'yo', 'ж': 'zh',
        'з': 'z', 'и': 'i', 'й': 'y', 'к': 'k', 'л': 'l', 'м': 'm', 'н': 'n', 'о': 'o',
        'п': 'p', 'р': 'r', 'с': 's', 'т': 't', 'у': 'u', 'ф': 'f', 'х': 'kh', 'ц': 'ts',
        'ч': 'ch', 'ш': 'sh', 'щ': 'shch', 'ъ': '', 'ы': 'y', 'ь': '', 'э': 'e', 'ю': 'yu',
        'я': 'ya',
    }

    # The keys for к, л, п, р, с, т used to be the Greek look-alike letters,
    # so those Cyrillic letters were never converted.
    cyrillic_to_greek_map = {
        'а': 'α', 'б': 'β', 'в': 'β', 'г': 'γ', 'д': 'δ', 'е': 'ε', 'ё': 'ιο', 'ж': 'ζ',
        'з': 'ζ', 'и': 'ι', 'й': 'ι', 'к': 'κ', 'л': 'λ', 'м': 'μ', 'н': 'ν', 'о': 'ο',
        'п': 'π', 'р': 'ρ', 'с': 'σ', 'т': 'τ', 'у': 'ου', 'ф': 'φ', 'х': 'χ', 'ц': 'τσ',
        'ч': 'τσ',
        'ш': 'σ', 'щ': 'σ',
        'ъ': '', 'ы': 'ι', 'ь': '', 'э': 'ε', 'ю': 'ιου',
        'я': 'ια',
    }

    lowercased_text = text.lower()

    if lang == 'grc':
        conversion_map = {**latin_to_greek_map, **cyrillic_to_greek_map}
    else:
        conversion_map = {**greek_to_latin_map, **cyrillic_to_latin_map}
        # Fold accented Greek/Cyrillic letters onto their base letter so they
        # can be transliterated; letters with their own mapping (й, ё) and
        # all Latin letters are left alone.
        folded = []
        for char in lowercased_text:
            if char not in conversion_map:
                base = ''.join(
                    c for c in unicodedata.normalize('NFD', char)
                    if not unicodedata.combining(c))
                if base in conversion_map:
                    char = base
            folded.append(char)
        lowercased_text = ''.join(folded)

    longest_key = max(len(key) for key in conversion_map)
    output_chars = []
    current_index = 0
    while current_index < len(lowercased_text):
        # Try the longest candidate first so digraphs win over single letters.
        for size in range(longest_key, 0, -1):
            candidate = lowercased_text[current_index:current_index + size]
            if candidate in conversion_map:
                output_chars.append(conversion_map[candidate])
                current_index += len(candidate)
                break
        else:
            # No mapping: keep the character unchanged.
            output_chars.append(lowercased_text[current_index])
            current_index += 1

    return ''.join(output_chars)
|
|
|
|
|
|
|
|
def other_tts(text='Hallov worlds Far over the',
              ref_s='wav/af_ZA_google-nwu_0184.wav'):
    """Synthesise ``text`` with StyleTTS2 in the voice of ``ref_s``.

    The text is transliterated to Latin script first.  The first row of the
    model output is duplicated at gains .99 and .94 into a two-row signal and
    written to '_speech.wav' at 24 kHz.

    Returns:
        Path of the written wav file.
    """
    latin_text = only_greek_or_only_latin(text, lang='eng')

    mono = _tts.inference(latin_text, ref_s=ref_s)[0:1, 0, :]
    two_channel = torch.cat([.99 * mono,
                             .94 * mono], 0).cpu().numpy()

    tmp_file = '_speech.wav'
    audiofile.write(tmp_file, two_channel, 24000)
    return tmp_file
|
|
|
|
|
|
|
|
def update_selected_voice(voice_filename):
    """Return the path of the reference wav for a voice button label."""
    return f'wav/{voice_filename}.wav'
|
|
|
|
|
|
|
|
# Markdown shown at the top of the "Speech Analysis" tab.
description = "".join([
    "Estimate **age**, **gender**, and **expression** ",
    "of the speaker contained in an audio file or microphone recording. \n",
    f"The model [{age_gender_model_name}]",
    f"(https://huggingface.co/{age_gender_model_name}) ",
    "recognises age and gender, ",
    f"whereas [{expression_model_name}]",
    f"(https://huggingface.co/{expression_model_name}) ",
    "recognises the expression dimensions arousal, dominance, and valence. ",
])
|
|
|
|
|
# Custom CSS passed to gr.Blocks: `.cool-button` styles the voice-selection
# buttons and `.cool-row` adds spacing below each row of them.
css_buttons = """
.cool-button {
    background-color: #1a2a40; /* Slightly lighter dark blue */
    color: white;
    padding: 15px 32px;
    text-align: center;
    font-size: 16px;
    border-radius: 12px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
    transition: all 0.3s ease-in-out;
    border: none;
    cursor: pointer;
}
.cool-button:hover {
    background-color: #1a2a40; /* Slightly lighter dark blue */
    transform: scale(1.05);
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
}
.cool-row {
    margin-bottom: 10px;
}
"""
|
|
|
|
|
# Gradio UI with three tabs: StyleTTS2 voice cloning ("other TTS"), speech
# analysis (age / gender / expression) and the MMS-based "audionar TTS".
with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
    with gr.Tab(label="other TTS"):
        # Path of the reference voice currently selected for other_tts().
        selected_voice = gr.State(value='wav/en_US_m-ailabs_mary_ann.wav')

        with gr.Row():
            voice_info = gr.Markdown(f'Vox = `{selected_voice.value}`')

        with gr.Row():
            text_input = gr.Textbox(
                label="Enter text for TTS:",
                placeholder="Type your message here...",
                lines=4,
                value="Farover the misty mountains cold too dungeons deep and caverns old.",
            )
            generate_button = gr.Button("Generate Audio", variant="primary")

        output_audio = gr.Audio(label="TTS Output")

        with gr.Column():
            voice_buttons = []
            # One button per reference voice, seven buttons per row.
            for i in range(0, len(VOICES), 7):
                with gr.Row(elem_classes=["cool-row"]):
                    for voice_filename in VOICES[i:i+7]:
                        # Strip the 'wav/' prefix and the '.wav' suffix.
                        voice_filename = voice_filename[4:-4]
                        button = gr.Button(voice_filename, elem_classes=["cool-button"])

                        # A single handler updates both the selected voice and
                        # the label.  The name is bound as a default argument
                        # so each button keeps its own value; no hidden
                        # Textbox is needed to carry it.
                        button.click(
                            fn=lambda v=voice_filename: (update_selected_voice(v), f'Vox = `{v}`'),
                            inputs=None,
                            outputs=[selected_voice, voice_info],
                        )
                        voice_buttons.append(button)

        generate_button.click(
            fn=other_tts,
            inputs=[text_input, selected_voice],
            outputs=output_audio
        )

    with gr.Tab(label="Speech Analysis"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
                # Named audio_input so the builtin input() is not shadowed.
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                    min_length=0.025,
                )
                gr.Examples(
                    [
                        "wav/female-46-neutral.wav",
                        "wav/female-20-happy.wav",
                        "wav/male-60-angry.wav",
                        "wav/male-27-sad.wav",
                    ],
                    [audio_input],
                    label="Examples from CREMA-D, ODbL v1.0 license",
                )
                gr.Markdown("Only the first two seconds of the audio will be processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Image(label="Expression")

        outputs = [output_age, output_gender, output_expression]
        submit_btn.click(recognize, audio_input, outputs)

    with gr.Tab("audionar TTS"):
        with gr.Row():
            text_input = gr.Textbox(
                lines=4,
                value='Η γρηγορη καφετι αλεπου πειδαει πανω απο τον τεμπελη σκυλο.',
                label="Type text for TTS"
            )
            lang_dropdown = gr.Dropdown(
                choices=language_names,
                label="TTS language",
                value="Ancient greek",
            )

        tts_button = gr.Button("Generate Audio")

        audio_output = gr.Audio(label="Generated Audio")

        tts_button.click(
            fn=audionar_tts,
            inputs=[text_input, lang_dropdown],
            outputs=audio_output
        )

demo.launch(debug=True)
|
|
|