import evaluate
import gradio as gr
import torch
from transformers import pipeline
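
# A minimal sketch for the "pull in emotion detection" TODO below, assuming a
# plain transformers text-classification pipeline; the model name here is an
# assumption for illustration, not something the original app specifies.
def detect_emotion(text):
    # returns e.g. [{"label": "anger", "score": 0.97}]
    emotion_classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
    return emotion_classifier(text)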

# pull in emotion detection (see the detect_emotion sketch above)
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements (sketched inside classify_toxicity)
# add logic to initiate mock notification when detected (sketched inside classify_toxicity)
# pull in misophonia-specific model (see the detect_misophonia_triggers sketch below)

# Create a Gradio interface with audio file and text inputs
def classify_toxicity(audio_file, text_input, classify_anxiety):
    # Transcribe the audio file using Whisper ASR; evaluate has no "whisper"
    # module, so the transformers ASR pipeline is used here instead
    if audio_file is not None:
        whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base")
        transcription_results = whisper_pipe(audio_file)
        # Extract the transcribed text
        transcribed_text = transcription_results["text"]
    else:
        transcribed_text = text_input
    # Load the selected toxicity classification model
    toxicity_module = evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")
    # toxicity_module = evaluate.load("toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
    toxicity_results = toxicity_module.compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print(toxicity_score)
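    # Sketch for the "mock notification when detected" TODO: the 0.5 threshold
    # is an assumed placeholder, not a tuned value
    if toxicity_score > 0.5:
        print(f"Mock notification: toxic content detected (score={toxicity_score:.2f})")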
    # Text classification (zero-shot over the category selected in the radio)
    device = 0 if torch.cuda.is_available() else -1
    classification_model = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", device=device)
    sequence_to_classify = transcribed_text
    candidate_labels = [classify_anxiety]
    # multi_label=True keeps the score meaningful with a single candidate label
    # (multi_label=False softmaxes over the candidates, so one label always scores 1.0)
    classification_output = classification_model(sequence_to_classify, candidate_labels, multi_label=True)
    print(classification_output)
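    # Sketch for the "associate labels with radio elements" TODO: with a single
    # candidate label, scores[0] is the entailment probability for the selected
    # category; the 0.5 cutoff is an assumed placeholder
    if classification_output["scores"][0] > 0.5:
        print(f"Mock notification: '{classification_output['labels'][0]}' detected")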
    return toxicity_score, transcribed_text
    # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"

with gr.Blocks() as iface:
    with gr.Column():
        classify = gr.Radio(["racial identity hate", "LGBTQ+ hate", "sexually explicit", "misophonia"], label="Category to flag")
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        submit_btn = gr.Button("Run")
    with gr.Column():
        out_score = gr.Textbox(label="Toxicity Score")
        out_text = gr.Textbox(label="Transcribed Text")
    # classify_toxicity returns two values, so two output components are needed
    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, text, classify], outputs=[out_score, out_text])

iface.launch()