# Hugging Face Space: Multimodal Emotion Classification (Text + Speech)
import gradio as gr
import pandas as pd
import plotly.express as px
from transformers import pipeline
# ------------------------------
# Load pretrained models
# ------------------------------
# Text: 7-class English emotion model; return_all_scores=True yields the full
# per-label distribution needed for late fusion.
# NOTE(review): return_all_scores is deprecated in newer transformers in favor
# of top_k=None, but top_k changes the output nesting — left as-is on purpose.
text_classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    return_all_scores=True,
)

# Speech: 4-class emotion recognition (hap/neu/sad/ang) over raw audio files.
audio_classifier = pipeline(
    "audio-classification",
    model="superb/wav2vec2-base-superb-er",
)
# ------------------------------
# Map emotion label -> emoji
# ------------------------------
# Fix: the original literals were mojibake (UTF-8 emoji mangled by a bad
# decode); restored to the intended emoji. Keys cover both the text model's
# full labels and the audio model's abbreviated labels (hap/neu/sad/ang).
EMOJI_MAP = {
    "anger": "\U0001F621",     # 😡
    "disgust": "\U0001F922",   # 🤢
    "fear": "\U0001F628",      # 😨
    "joy": "\U0001F600",       # 😀
    "neutral": "\U0001F610",   # 😐
    "sadness": "\U0001F622",   # 😢
    "surprise": "\U0001F632",  # 😲
    "hap": "\U0001F600",       # 😀 (audio model label)
    "neu": "\U0001F610",       # 😐
    "sad": "\U0001F622",       # 😢
    "ang": "\U0001F621",       # 😡
}
# ------------------------------
# Fusion function
# ------------------------------
def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
    """Late-fuse text and audio emotion predictions by weighted sum.

    Each prediction list holds dicts with 'label' and 'score' keys (the
    transformers pipeline output format). Scores within each modality are
    normalized to sum to 1, then combined as w_text * text + w_audio * audio
    per label. Labels are NOT unified across modalities (e.g. 'joy' vs 'hap'
    stay separate), matching the original behavior.

    Args:
        text_preds: list of {'label', 'score'} dicts from the text model, or None.
        audio_preds: list of {'label', 'score'} dicts from the audio model, or None.
        w_text: weight applied to the text modality.
        w_audio: weight applied to the audio modality.

    Returns:
        dict with 'fused_label' (best label, or 'none' if no input),
        'fused_score' (rounded to 3 decimals) and 'all_scores' (per-label dict).
    """
    labels = set()
    if text_preds:
        labels |= {p["label"] for p in text_preds}
    if audio_preds:
        labels |= {p["label"] for p in audio_preds}
    scores = {label: 0.0 for label in labels}

    def _normalize(preds):
        # Fix: guard against an all-zero score list, which previously raised
        # ZeroDivisionError.
        total = sum(p["score"] for p in preds)
        if total <= 0:
            return {p["label"]: 0.0 for p in preds}
        return {p["label"]: p["score"] / total for p in preds}

    if text_preds:
        t_norm = _normalize(text_preds)
        for label in labels:
            scores[label] += w_text * t_norm.get(label, 0.0)
    if audio_preds:
        a_norm = _normalize(audio_preds)
        for label in labels:
            scores[label] += w_audio * a_norm.get(label, 0.0)

    best = max(scores.items(), key=lambda kv: kv[1]) if scores else ("none", 0)
    return {"fused_label": best[0], "fused_score": round(best[1], 3), "all_scores": scores}
# ------------------------------
# Create bar chart
# ------------------------------
def make_bar_chart(scores_dict, title="Emotion Scores"):
    """Render a per-emotion probability bar chart.

    Args:
        scores_dict: mapping of emotion label -> probability in [0, 1].
        title: chart title.

    Returns:
        A plotly Figure with one colored bar per emotion, value labels
        outside the bars, and the y-axis fixed to [0, 1].
    """
    frame = pd.DataFrame(
        {"Emotion": list(scores_dict), "Score": list(scores_dict.values())}
    )
    fig = px.bar(
        frame,
        x="Emotion",
        y="Score",
        text="Score",
        title=title,
        range_y=[0, 1],
        color="Emotion",
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
    fig.update_traces(texttemplate="%{text:.2f}", textposition="outside")
    fig.update_layout(
        yaxis_title="Probability",
        xaxis_title="Emotion",
        showlegend=False,
    )
    return fig
# ------------------------------
# Prediction function
# ------------------------------
def predict(text, audio, w_text, w_audio):
    """Classify emotion from text and/or audio, fuse, and build UI outputs.

    Args:
        text: user-entered string (may be empty).
        audio: filepath to an uploaded audio clip (may be None).
        w_text: fusion weight for the text modality.
        w_audio: fusion weight for the audio modality.

    Returns:
        (markdown_string, plotly_figure_or_None) matching the two Gradio
        outputs: a headline with the fused label + emoji, and the fused
        score bar chart.
    """
    text_preds, audio_preds = None, None
    if text:
        # Pipeline returns a list per input; [0] is the score list for our
        # single string.
        text_preds = text_classifier(text)[0]
    if audio:
        audio_preds = audio_classifier(audio)

    # Fix: guard the no-input case instead of reporting a meaningless "NONE".
    if not text_preds and not audio_preds:
        return "### Please provide text and/or audio input.", None

    fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)

    # Display final predicted emotion with emoji.
    label = fused["fused_label"]
    emoji = EMOJI_MAP.get(label, "")
    final_emotion = f"### Final Predicted Emotion: {label.upper()} {emoji} (score: {fused['fused_score']})"

    # Fix: the UI binds this to a single gr.Plot, which cannot render the
    # list of figures the original returned; return just the fused chart.
    chart = make_bar_chart(fused["all_scores"], "Fused Emotion Scores")
    return final_emotion, chart
# ------------------------------
# Build Gradio interface
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## π Multimodal Emotion Classification (Text + Speech)")
    with gr.Row():
        # Left column: inputs and fusion weights.
        with gr.Column():
            txt = gr.Textbox(
                label="Text input",
                placeholder="Type something emotional...",
            )
            aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
            w1 = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.5, label="Text weight (w_text)"
            )
            w2 = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.5, label="Audio weight (w_audio)"
            )
            btn = gr.Button("Predict")
        # Right column: fused prediction headline and score chart.
        with gr.Column():
            final_label = gr.Markdown(label="Predicted Emotion")
            chart_output = gr.Plot(label="Emotion Scores")

    # Button click triggers prediction.
    btn.click(
        fn=predict,
        inputs=[txt, aud, w1, w2],
        outputs=[final_label, chart_output],
    )

demo.launch()