Spaces:
Sleeping
Sleeping
File size: 4,288 Bytes
9258fc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import gradio as gr
from transformers import pipeline
import pandas as pd
import plotly.express as px
# ------------------------------
# Load pretrained models
# ------------------------------
# Text emotion classifier. `top_k=None` asks the pipeline to return the score
# of every label; it is the supported replacement for the deprecated
# `return_all_scores=True` flag.
text_classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
)

# Speech emotion classifier; it is called with the audio file path produced by
# the Gradio audio widget.
audio_classifier = pipeline(
    "audio-classification",
    model="superb/wav2vec2-base-superb-er",
)
# ------------------------------
# Map emotion to emoji
# ------------------------------
# Emoji displayed next to the predicted label. The code points are written as
# ASCII escapes so the mapping cannot be corrupted when the file is re-saved or
# served in a non-UTF-8 encoding (the previous literals were mojibake).
EMOJI_MAP = {
    # Labels emitted by the text model.
    "anger": "\U0001F621",     # pouting face
    "disgust": "\U0001F922",   # nauseated face
    "fear": "\U0001F628",      # fearful face
    # NOTE(review): the original joy/neutral glyphs were not recoverable from
    # the garbled bytes; a smiling face and a neutral face are presumed.
    "joy": "\U0001F604",       # grinning face with smiling eyes
    "neutral": "\U0001F610",   # neutral face
    "sadness": "\U0001F622",   # crying face
    "surprise": "\U0001F632",  # astonished face
    # Abbreviated labels emitted by the audio model.
    "hap": "\U0001F604",
    "neu": "\U0001F610",
    "sad": "\U0001F622",
    "ang": "\U0001F621",
}
# ------------------------------
# Fusion function
# ------------------------------
# The audio model reports abbreviated labels while the text model spells them
# out. Without this mapping the two label sets are disjoint and the modalities
# are never actually combined for any emotion.
_LABEL_ALIASES = {
    "ang": "anger",
    "hap": "joy",
    "neu": "neutral",
    "sad": "sadness",
}


def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
    """Combine text and audio emotion predictions by weighted averaging.

    Args:
        text_preds: Iterable of ``{"label": str, "score": float}`` dicts from
            the text model, or ``None``/empty when no text was supplied.
        audio_preds: Same structure for the audio model, or ``None``/empty.
        w_text: Relative weight of the text distribution.
        w_audio: Relative weight of the audio distribution.

    Returns:
        A dict with ``fused_label`` (highest-scoring label, or ``"none"`` when
        there are no predictions), ``fused_score`` (its score rounded to three
        decimals) and ``all_scores`` (label -> fused score, in first-seen
        order).

    Each modality is normalised to sum to 1, abbreviated labels are mapped onto
    their long names, and the weights are renormalised over the modalities that
    are actually present, so the fused scores form a probability distribution
    even when only one input is given.
    """

    def normalize(preds):
        # Accumulate per canonical label so aliases collapse onto one key.
        dist = {}
        for pred in preds:
            label = _LABEL_ALIASES.get(pred["label"], pred["label"])
            dist[label] = dist.get(label, 0.0) + pred["score"]
        total = sum(dist.values())
        if total <= 0:
            # Degenerate input: avoid dividing by zero.
            return {label: 0.0 for label in dist}
        return {label: score / total for label, score in dist.items()}

    modalities = []
    if text_preds:
        modalities.append((w_text, normalize(text_preds)))
    if audio_preds:
        modalities.append((w_audio, normalize(audio_preds)))

    if not modalities:
        return {"fused_label": "none", "fused_score": 0, "all_scores": {}}

    total_weight = sum(weight for weight, _ in modalities)
    if total_weight > 0:
        weights = [weight / total_weight for weight, _ in modalities]
    else:
        # Every available modality has zero weight: weigh them equally rather
        # than returning an all-zero result.
        weights = [1.0 / len(modalities)] * len(modalities)

    # A dict (not a set) keeps label order deterministic between runs, which
    # fixes both the chart order and the tie-breaking in max().
    scores = {}
    for weight, (_, dist) in zip(weights, modalities):
        for label, prob in dist.items():
            scores[label] = scores.get(label, 0.0) + weight * prob

    best_label, best_score = max(scores.items(), key=lambda item: item[1])
    return {
        "fused_label": best_label,
        "fused_score": round(best_score, 3),
        "all_scores": scores,
    }
# ------------------------------
# Create bar chart
# ------------------------------
def make_bar_chart(scores_dict, title="Emotion Scores"):
    """Render a label -> score mapping as a Plotly bar chart.

    Args:
        scores_dict: Mapping whose keys become the x-axis categories and whose
            values become the bar heights.
        title: Chart title.

    Returns:
        The Plotly figure, with one coloured bar per key, the y-axis fixed to
        the range 0..1 and each bar annotated with its value to two decimals.
    """
    emotions = list(scores_dict.keys())
    values = list(scores_dict.values())
    frame = pd.DataFrame({"Emotion": emotions, "Score": values})

    figure = px.bar(
        frame,
        x="Emotion",
        y="Score",
        text="Score",
        title=title,
        range_y=[0, 1],
        color="Emotion",
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
    # Print each bar's value above it, formatted to two decimals.
    figure.update_traces(texttemplate='%{text:.2f}', textposition='outside')
    # Colour already encodes the category shown on the x-axis, so no legend.
    figure.update_layout(
        yaxis_title="Probability",
        xaxis_title="Emotion",
        showlegend=False,
    )
    return figure
# ------------------------------
# Prediction function
# ------------------------------
def _combine_charts(charts):
    """Merge bar-chart figures into one figure with side-by-side panels."""
    if len(charts) == 1:
        return charts[0]

    # Local import: plotly is already a dependency of this file, and importing
    # here keeps the block self-contained.
    from plotly.subplots import make_subplots

    combined = make_subplots(
        rows=1,
        cols=len(charts),
        subplot_titles=[chart.layout.title.text for chart in charts],
    )
    for col, chart in enumerate(charts, start=1):
        for trace in chart.data:
            combined.add_trace(trace, row=1, col=col)
    combined.update_yaxes(range=[0, 1])
    # "relative" is the bar mode plotly.express uses for its own bar charts.
    combined.update_layout(barmode="relative", showlegend=False)
    return combined


def predict(text, audio, w_text, w_audio):
    """Classify the given text and/or audio and fuse the two predictions.

    Args:
        text: Text to classify; ``None``, empty or whitespace-only is ignored.
        audio: Audio input passed to the audio classifier; falsy is ignored.
        w_text: Fusion weight of the text prediction.
        w_audio: Fusion weight of the audio prediction.

    Returns:
        A ``(markdown, figure)`` tuple: a Markdown line announcing the fused
        emotion, and a single Plotly figure holding the per-modality and fused
        bar charts. When no input is given, a prompt message and ``None``.
    """
    has_text = bool(text and text.strip())
    has_audio = bool(audio)
    if not (has_text or has_audio):
        # Nothing to classify: do not run the models or draw an empty chart.
        return "### Please provide text and/or audio input.", None

    text_preds, audio_preds = None, None
    if has_text:
        raw = text_classifier(text)
        # The pipeline may wrap the per-label scores of a single input in an
        # outer list; unwrap it only when that nesting is present.
        text_preds = raw[0] if raw and isinstance(raw[0], list) else raw
    if has_audio:
        audio_preds = audio_classifier(audio)

    fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)

    # Display final predicted emotion with emoji
    label = fused['fused_label']
    emoji = EMOJI_MAP.get(label, "")
    final_emotion = f"### Final Predicted Emotion: {label.upper()} {emoji} (score: {fused['fused_score']})"

    # Bar charts
    charts = []
    if text_preds:
        charts.append(make_bar_chart({p['label']: p['score'] for p in text_preds}, "Text Emotion Scores"))
    if audio_preds:
        charts.append(make_bar_chart({p['label']: p['score'] for p in audio_preds}, "Audio Emotion Scores"))
    charts.append(make_bar_chart(fused['all_scores'], "Fused Emotion Scores"))

    # The UI binds this output to a single plot component, so hand back one
    # figure instead of a Python list of figures.
    return final_emotion, _combine_charts(charts)
# ------------------------------
# Build Gradio interface
# ------------------------------
# Two-column layout: inputs and fusion weights on the left, results on the right.
with gr.Blocks() as demo:
    # NOTE(review): the heading emoji was garbled in the source; the
    # performing-arts mask is presumed. Written as an escape so it cannot be
    # corrupted by a non-UTF-8 save.
    gr.Markdown("## \U0001F3AD Multimodal Emotion Classification (Text + Speech)")
    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
            aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
            w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight (w_text)")
            w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight (w_audio)")
            btn = gr.Button("Predict")
        with gr.Column():
            final_label = gr.Markdown(label="Predicted Emotion")
            chart_output = gr.Plot(label="Emotion Scores")
    # Button click triggers prediction
    btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[final_label, chart_output])

# Start the server only when run as a script, so importing this module (for
# tests or reuse) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
|