sachin7777777 commited on
Commit
9258fc6
Β·
verified Β·
1 Parent(s): cd42213

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import pandas as pd
4
+ import plotly.express as px
5
+
6
+ # ------------------------------
7
+ # Load pretrained models
8
+ # ------------------------------
9
+ text_classifier = pipeline(
10
+ "text-classification",
11
+ model="j-hartmann/emotion-english-distilroberta-base",
12
+ return_all_scores=True
13
+ )
14
+
15
+ audio_classifier = pipeline(
16
+ "audio-classification",
17
+ model="superb/wav2vec2-base-superb-er"
18
+ )
19
+
20
+ # ------------------------------
21
+ # Map emotion to emoji
22
+ # ------------------------------
23
+ EMOJI_MAP = {
24
+ "anger": "😑",
25
+ "disgust": "🀒",
26
+ "fear": "😨",
27
+ "joy": "πŸ˜„",
28
+ "neutral": "😐",
29
+ "sadness": "😒",
30
+ "surprise": "😲",
31
+ "hap": "πŸ˜„", # for audio model
32
+ "neu": "😐",
33
+ "sad": "😒",
34
+ "ang": "😑"
35
+ }
36
+
37
+ # ------------------------------
38
+ # Fusion function
39
+ # ------------------------------
40
+ def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
41
+ labels = set()
42
+ if text_preds:
43
+ labels |= {p['label'] for p in text_preds}
44
+ if audio_preds:
45
+ labels |= {p['label'] for p in audio_preds}
46
+ scores = {l: 0.0 for l in labels}
47
+
48
+ def normalize(preds):
49
+ s = sum(p['score'] for p in preds)
50
+ return {p['label']: p['score']/s for p in preds}
51
+
52
+ if text_preds:
53
+ t_norm = normalize(text_preds)
54
+ for l in labels:
55
+ scores[l] += w_text * t_norm.get(l, 0)
56
+ if audio_preds:
57
+ a_norm = normalize(audio_preds)
58
+ for l in labels:
59
+ scores[l] += w_audio * a_norm.get(l, 0)
60
+
61
+ best = max(scores.items(), key=lambda x: x[1]) if scores else ("none", 0)
62
+ return {"fused_label": best[0], "fused_score": round(best[1], 3), "all_scores": scores}
63
+
64
+ # ------------------------------
65
+ # Create bar chart
66
+ # ------------------------------
67
+ def make_bar_chart(scores_dict, title="Emotion Scores"):
68
+ df = pd.DataFrame({
69
+ "Emotion": list(scores_dict.keys()),
70
+ "Score": list(scores_dict.values())
71
+ })
72
+ fig = px.bar(df, x="Emotion", y="Score", text="Score",
73
+ title=title, range_y=[0,1],
74
+ color="Emotion", color_discrete_sequence=px.colors.qualitative.Bold)
75
+ fig.update_traces(texttemplate='%{text:.2f}', textposition='outside')
76
+ fig.update_layout(yaxis_title="Probability", xaxis_title="Emotion", showlegend=False)
77
+ return fig
78
+
79
+ # ------------------------------
80
+ # Prediction function
81
+ # ------------------------------
82
+ def predict(text, audio, w_text, w_audio):
83
+ text_preds, audio_preds = None, None
84
+ if text:
85
+ text_preds = text_classifier(text)[0]
86
+ if audio:
87
+ audio_preds = audio_classifier(audio)
88
+ fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)
89
+
90
+ # Display final predicted emotion with emoji
91
+ label = fused['fused_label']
92
+ emoji = EMOJI_MAP.get(label, "")
93
+ final_emotion = f"### Final Predicted Emotion: {label.upper()} {emoji} (score: {fused['fused_score']})"
94
+
95
+ # Bar charts
96
+ charts = []
97
+ if text_preds:
98
+ charts.append(make_bar_chart({p['label']: p['score'] for p in text_preds}, "Text Emotion Scores"))
99
+ if audio_preds:
100
+ charts.append(make_bar_chart({p['label']: p['score'] for p in audio_preds}, "Audio Emotion Scores"))
101
+ charts.append(make_bar_chart(fused['all_scores'], "Fused Emotion Scores"))
102
+
103
+ return final_emotion, charts
104
+
105
+ # ------------------------------
106
+ # Build Gradio interface
107
+ # ------------------------------
108
+ with gr.Blocks() as demo:
109
+ gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech)")
110
+
111
+ with gr.Row():
112
+ with gr.Column():
113
+ txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
114
+ aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
115
+ w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight (w_text)")
116
+ w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight (w_audio)")
117
+ btn = gr.Button("Predict")
118
+ with gr.Column():
119
+ final_label = gr.Markdown(label="Predicted Emotion")
120
+ chart_output = gr.Plot(label="Emotion Scores")
121
+
122
+ # Button click triggers prediction
123
+ btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[final_label, chart_output])
124
+
125
+ demo.launch()