sachin7777777 commited on
Commit
3b2fd42
Β·
verified Β·
1 Parent(s): 7a07632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -64
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import plotly.express as px
5
 
6
  # ------------------------------
7
- # Load pretrained models (CPU-friendly)
8
  # ------------------------------
9
  text_classifier = pipeline(
10
  "text-classification",
@@ -12,12 +12,6 @@ text_classifier = pipeline(
12
  top_k=None # returns all scores
13
  )
14
 
15
- # Use a small, public audio model
16
- audio_classifier = pipeline(
17
- "audio-classification",
18
- model="superb/wav2vec2-small-superb-er" # small model
19
- )
20
-
21
  # ------------------------------
22
  # Map emotion to emoji
23
  # ------------------------------
@@ -28,42 +22,11 @@ EMOJI_MAP = {
28
  "joy": "πŸ˜„",
29
  "neutral": "😐",
30
  "sadness": "😒",
31
- "surprise": "😲",
32
- "hap": "πŸ˜„", # audio model labels
33
- "neu": "😐",
34
- "sad": "😒",
35
- "ang": "😑"
36
  }
37
 
38
  # ------------------------------
39
- # Fusion function
40
- # ------------------------------
41
- def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
42
- labels = set()
43
- if text_preds:
44
- labels |= {p['label'] for p in text_preds}
45
- if audio_preds:
46
- labels |= {p['label'] for p in audio_preds}
47
- scores = {l: 0.0 for l in labels}
48
-
49
- def normalize(preds):
50
- total = sum(p['score'] for p in preds)
51
- return {p['label']: p['score']/total for p in preds}
52
-
53
- if text_preds:
54
- t_norm = normalize(text_preds)
55
- for l in labels:
56
- scores[l] += w_text * t_norm.get(l, 0)
57
- if audio_preds:
58
- a_norm = normalize(audio_preds)
59
- for l in labels:
60
- scores[l] += w_audio * a_norm.get(l, 0)
61
-
62
- best = max(scores.items(), key=lambda x: x[1]) if scores else ("none", 0)
63
- return {"fused_label": best[0], "fused_score": round(best[1], 3), "all_scores": scores}
64
-
65
- # ------------------------------
66
- # Bar chart function
67
  # ------------------------------
68
  def make_bar_chart(scores_dict, title="Emotion Scores"):
69
  df = pd.DataFrame({
@@ -80,41 +43,45 @@ def make_bar_chart(scores_dict, title="Emotion Scores"):
80
  # ------------------------------
81
  # Prediction function
82
  # ------------------------------
83
- def predict(text, audio, w_text, w_audio):
84
- text_preds, audio_preds = None, None
85
- if text:
86
- text_preds = text_classifier(text) # list of dicts
87
- if audio:
88
- audio_preds = audio_classifier(audio) # list of dicts
89
-
90
- fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)
91
-
92
- # Final emotion with emoji
93
- label = fused['fused_label']
94
- emoji = EMOJI_MAP.get(label, "")
95
- final_emotion = f"### Final Predicted Emotion: {label.upper()} {emoji} (score: {fused['fused_score']})"
96
-
97
- # Fused bar chart
98
- chart = make_bar_chart(fused['all_scores'], "Fused Emotion Scores")
99
- return final_emotion, chart
 
 
 
 
 
 
 
100
 
101
  # ------------------------------
102
  # Build Gradio interface
103
  # ------------------------------
104
  with gr.Blocks() as demo:
105
- gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech)")
106
 
107
  with gr.Row():
108
  with gr.Column():
109
  txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
110
- aud = gr.Audio(type="filepath", label="Upload speech (wav/mp3)")
111
- w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Text weight (w_text)")
112
- w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Audio weight (w_audio)")
113
  btn = gr.Button("Predict")
114
  with gr.Column():
115
- final_label = gr.Markdown(label="Predicted Emotion")
116
- chart_output = gr.Plot(label="Fused Emotion Scores")
117
 
118
- btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[final_label, chart_output])
119
 
120
  demo.launch()
 
4
  import plotly.express as px
5
 
6
  # ------------------------------
7
+ # Load pretrained text model
8
  # ------------------------------
9
  text_classifier = pipeline(
10
  "text-classification",
 
12
  top_k=None # returns all scores
13
  )
14
 
 
 
 
 
 
 
15
  # ------------------------------
16
  # Map emotion to emoji
17
  # ------------------------------
 
22
  "joy": "πŸ˜„",
23
  "neutral": "😐",
24
  "sadness": "😒",
25
+ "surprise": "😲"
 
 
 
 
26
  }
27
 
28
  # ------------------------------
29
+ # Create bar chart
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # ------------------------------
31
  def make_bar_chart(scores_dict, title="Emotion Scores"):
32
  df = pd.DataFrame({
 
43
  # ------------------------------
44
  # Prediction function
45
  # ------------------------------
46
+ def predict(text, w_text=1.0):
47
+ if not text:
48
+ return "Please enter text.", None
49
+ preds = text_classifier(text)[0] # get all scores
50
+ scores = {p['label']: p['score'] for p in preds}
51
+ best_label = max(scores, key=scores.get)
52
+ emoji = EMOJI_MAP.get(best_label, "")
53
+
54
+ # Animate emoji with simple bouncing
55
+ final_emotion_html = f"""
56
+ <div style="font-size:80px; text-align:center; animation: bounce 1s infinite;">
57
+ {emoji}
58
+ </div>
59
+ <h3 style="text-align:center;">{best_label.upper()} (score: {scores[best_label]:.2f})</h3>
60
+ <style>
61
+ @keyframes bounce {{
62
+ 0%, 20%, 50%, 80%, 100% {{transform: translateY(0);}}
63
+ 40% {{transform: translateY(-20px);}}
64
+ 60% {{transform: translateY(-10px);}}
65
+ }}
66
+ </style>
67
+ """
68
+ chart = make_bar_chart(scores, "Text Emotion Scores")
69
+ return final_emotion_html, chart
70
 
71
  # ------------------------------
72
  # Build Gradio interface
73
  # ------------------------------
74
  with gr.Blocks() as demo:
75
+ gr.Markdown("## 🎭 Text Emotion Classification with Emoji Animation")
76
 
77
  with gr.Row():
78
  with gr.Column():
79
  txt = gr.Textbox(label="Text input", placeholder="Type something emotional...")
 
 
 
80
  btn = gr.Button("Predict")
81
  with gr.Column():
82
+ final_label = gr.HTML(label="Predicted Emotion")
83
+ chart_output = gr.Plot(label="Emotion Scores")
84
 
85
+ btn.click(fn=predict, inputs=[txt], outputs=[final_label, chart_output])
86
 
87
  demo.launch()