Spaces:
Paused
Paused
| import re | |
| import os | |
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| from transformers import AutoTokenizer,ViTImageProcessor | |
| from unidecode import unidecode | |
| from models import * | |
# Romanian BERT tokenizer shared by all text models below.
tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
# ViT image processor, reused to normalize mel-spectrogram "images" for the audio models.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
def preprocess(x: str) -> str:
    """Normalize raw input text before tokenization.

    Transliterates to ASCII, lowercases, strips bracketed tags and
    asterisks, replaces punctuation runs with single spaces, and
    squeezes consecutive repeated characters (e.g. "loool" -> "lol").
    """
    text = unidecode(x).lower()
    # Ordered (pattern, replacement) pipeline; order matters because the
    # punctuation pass introduces the spaces collapsed afterwards.
    substitutions = (
        (r"\[[a-z]+\]", ""),       # drop bracketed annotations like [x]
        (r"\*", ""),               # drop asterisks (censoring marks)
        (r"[^a-zA-Z0-9]+", " "),   # punctuation/whitespace runs -> one space
        (r" +", " "),              # collapse any remaining space runs
        (r"(.)\1+", r"\1"),        # squeeze repeated characters
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Text classes; zipped positionally against the text model's output probabilities.
label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]
# Audio classes; zipped positionally against the audio model's output probabilities.
audio_label_names = ["Laughter", "Sigh", "Cough", "Throat clearing", "Sneeze", "Sniff"]
def ssl_predict(in_text, model_type):
    """Classify a Romanian text into offense categories.

    Args:
        in_text: raw user text; run through ``preprocess`` first.
        model_type: one of "fixmatch", "freematch", "mixmatch",
            "contrastive_reg", "label_propagation" — selects which
            trained checkpoint is loaded.

    Returns:
        Dict mapping each label in ``label_names`` to its predicted
        probability as a float.

    Raises:
        ValueError: if ``model_type`` is not a known training method.
    """
    preprocessed = preprocess(in_text)
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf"
    )
    model_inputs = [toks["input_ids"], toks["attention_mask"]]
    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # call; caching per model_type would avoid repeated disk reads.
    if model_type == "fixmatch":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/fixmatch_tune")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "freematch":
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "mixmatch":
        model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
        model.cls_head.load_weights("./checkpoints/mixmatch")
        preds = model(model_inputs, training=False)
    elif model_type == "contrastive_reg":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/contrastive")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "label_propagation":
        model = LPModel()
        model.load_weights("./checkpoints/label_prop")
        preds = model(model_inputs, training=False)
    else:
        # Previously an unknown selection fell through with preds=None and
        # crashed later with an opaque TypeError; fail fast instead.
        raise ValueError(f"Unknown model_type: {model_type!r}")
    probs = preds[0].numpy()
    return {name: float(p) for name, p in zip(label_names, probs)}
def ssl_predict2(audio_file, model_type):
    """Classify a vocal sound clip into one of the VocalSound classes.

    Args:
        audio_file: uploaded file object; ``audio_file.name`` is the path
            loaded with librosa at 16 kHz.
        model_type: one of "fixmatch", "freematch", "mixmatch" — selects
            which trained checkpoint is loaded.

    Returns:
        Dict mapping each label in ``audio_label_names`` to its predicted
        probability as a float.

    Raises:
        ValueError: if ``model_type`` is not a known training method.
    """
    signal, sr = librosa.load(audio_file.name, sr=16000)
    # Fix the clip to exactly 5 seconds at 16 kHz: zero-pad short clips,
    # truncate long ones.
    length = 5 * 16000
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), 'constant')
    else:
        signal = signal[:length]
    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)
    # Min-max normalize to [0, 1]. Guard the degenerate constant case
    # (e.g. a fully silent clip), which previously divided by zero and
    # filled the input with NaNs.
    spec_min, spec_max = spectrogram.min(), spectrogram.max()
    if spec_max > spec_min:
        spectrogram = (spectrogram - spec_min) / (spec_max - spec_min)
    else:
        spectrogram = np.zeros_like(spectrogram)
    spectrogram = spectrogram.astype("float32")
    # Replicate the single-channel spectrogram to 3 channels so the ViT
    # processor accepts it as an RGB image; mean/std undo its default
    # normalization in favor of these dataset-specific statistics.
    inputs = processor.preprocess(
        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf"
    )
    if model_type == "fixmatch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
        preds, _ = model(inputs["pixel_values"], training=False)
    elif model_type == "freematch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
        model.cls_head.load_weights("./checkpoints/audio_freematch")
        preds, _ = model(inputs["pixel_values"], training=False)
    elif model_type == "mixmatch":
        model = AudioMixMatch(encoder_name="andrei-saceleanu/vit-base-mixmatch")
        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
        preds = model(inputs["pixel_values"], training=False)
    else:
        # Previously an unknown selection fell through with preds=None and
        # crashed later with an opaque TypeError; fail fast instead.
        raise ValueError(f"Unknown model_type: {model_type!r}")
    probs = preds[0].numpy()
    return {name: float(p) for name, p in zip(audio_label_names, probs)}
# Textbox display modes indexed by the "Safe view" checkbox state:
# text_types[int(checked)] -> "password" when safe view is on.
text_types = ["text", "password"]

# Text examples, one per line (example text and metadata separated by "##").
with open(file="examples.txt", mode="r", encoding="UTF-8") as fin:
    # rstrip("\n") instead of the old elem[:-1]: the slice silently chopped
    # the last character of the final line when the file had no trailing
    # newline.
    lines = [elem.rstrip("\n") for elem in fin.readlines()]

# Audio examples: one relative filename per line, resolved under ./audio_data.
DATA_DIR = os.path.abspath("./audio_data")
with open(file="audio_examples.txt", mode="r", encoding="UTF-8") as fin:
    lines2 = [os.path.join(DATA_DIR, elem.strip()) for elem in fin.readlines()]
# Gradio front-end: two tabs sharing the same layout (inputs + model picker
# on the left, prediction label on the right), wired to the predict
# functions above.
with gr.Blocks() as ssl_interface:
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                # Input starts masked ("password") because example texts may
                # be offensive; toggled by the "Safe view" checkbox below.
                in_text = gr.Textbox(label="Input text",type="password")
                safe_view = gr.Checkbox(value=True,label="Safe view")
                model_list = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"],
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
                # Clickable example table; type="index" makes the click
                # handler receive the clicked row's index (used to look up
                # the corresponding line in `lines`).
                ds = gr.Dataset(
                    components=[gr.Textbox(visible=False),gr.Textbox(visible=False)],
                    headers=["Id","Expected class"],
                    samples=[["1","ABUSE"],["2","INSULT"],["3","PROFANITY"],["4","OTHER"]],
                    type="index"
                )
            with gr.Column():
                out_field = gr.Label(num_top_classes=4, label="Prediction")
        # Switch the textbox between masked and plain display:
        # text_types[int(checked)] is "password" when checked, "text" otherwise.
        safe_view.change(
            fn= lambda checked: gr.update(type=text_types[int(checked)]),
            inputs=safe_view,
            outputs=in_text
        )
        # Load the clicked example's text (the part before "##") into the input.
        ds.click(
            fn=lambda idx: gr.update(value=lines[idx].split("##")[0]),
            inputs=ds,
            outputs=in_text
        )
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field]
        )
        # Reset both the input text and the prediction display.
        clear_btn.click(
            fn=lambda: [None for _ in range(2)],
            inputs=None,
            outputs=[in_text, out_field],
            queue=False
        )
    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_count="single",
                    file_types=["audio"]
                )
                model_list2 = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch"],
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
                # Example table for audio; index maps into `lines2` (paths).
                ds2 = gr.Dataset(
                    components=[gr.Textbox(visible=False),gr.Textbox(visible=False)],
                    headers=["Id","Expected class"],
                    samples=[["1","Laughter"],["2","Cough"],["3","Sneeze"],["4","Throatclearing"]],
                    type="index"
                )
            with gr.Column():
                out_field2 = gr.Label(num_top_classes=6, label="Prediction")
        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2]
        )
        # Reset both the file input and the prediction display.
        clear_btn2.click(
            fn=lambda: [None for _ in range(2)],
            inputs=None,
            outputs=[audio_file, out_field2],
            queue=False
        )
        # Load the clicked example's audio path into the File input.
        ds2.click(
            fn=lambda idx: gr.update(value=lines2[idx]),
            inputs=ds2,
            outputs=audio_file
        )
# Bind on all interfaces (container-friendly) on port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)