OliverPerrin commited on
Commit
cd865e2
·
1 Parent(s): e17e016

Emotion Fix in Gradio

Browse files
Files changed (1) hide show
  1. scripts/demo_gradio.py +33 -49
scripts/demo_gradio.py CHANGED
@@ -6,6 +6,7 @@ import json
6
  import sys
7
  from pathlib import Path
8
  from typing import Iterable, Sequence
 
9
  import gradio as gr
10
  from gradio.themes import Soft
11
  import matplotlib.pyplot as plt
@@ -70,51 +71,43 @@ def map_compression_to_length(compression: int, max_model_length: int = 512):
70
  return int(ratio * max_model_length)
71
 
72
  def predict(text: str, compression: int):
73
- """
74
- Main predcition function for the Gradio interface.
75
- Args:
76
- text: Text to process
77
- compression: Compression percentage (20-80)
78
- Returns:
79
- Tuple of (summary_html, emotion_plot, topic_output, attention_fig, download_data)
80
- """
81
  if not text or not text.strip():
82
  return (
83
  "Please enter some text to analyze.",
84
  None,
85
  "No topic prediction available",
86
  None,
87
- gr.update(value=None, visible=False),
88
  )
89
  try:
90
  pipeline = get_pipeline()
91
  max_len = map_compression_to_length(compression)
92
- logger.info(f"Generating summary with max length of {max_len}")
93
-
94
- # Get the predictions
95
  summary = pipeline.summarize([text], max_length=max_len)[0]
96
  emotions = pipeline.predict_emotions([text])[0]
97
  topic = pipeline.predict_topics([text])[0]
98
-
99
  summary_html = format_summary(text, summary)
100
  emotion_plot = create_emotion_plot(emotions)
101
  topic_output = format_topic(topic)
102
  attention_fig = create_attention_heatmap(text, summary, pipeline)
103
- download_data = prepare_download(text, summary, emotions, topic)
 
104
 
105
- return summary_html, emotion_plot, topic_output, attention_fig, gr.update(
106
- value=download_data,
107
- visible=True,
108
- )
109
-
110
- except Exception as e:
111
- logger.error(f"Prediction error: {e}", exc_info=True)
112
  error_msg = "Prediction failed. Check logs for details."
113
- return error_msg, None, "Error", None, gr.update(value=None, visible=False)
114
-
115
- def format_summary(original: str, summary:str) ->str:
116
- """Format original and summary text for display"""
117
- html = f"""
 
118
  <div style="padding: 10px; border-radius: 5px;">
119
  <h3>Original Text</h3>
120
  <p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">
@@ -126,17 +119,15 @@ def format_summary(original: str, summary:str) ->str:
126
  </p>
127
  </div>
128
  """
129
- return html
130
 
131
- def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]) -> Figure | None:
132
- """
133
- Create bar plot for emotion predictions.
134
- Args:
135
- emotions: Dict with 'labels' and 'scores' keys
136
- """
137
  if isinstance(emotions, EmotionPrediction):
138
- labels = emotions.labels
139
- scores = emotions.scores
140
  else:
141
  labels = list(emotions.get("labels", []))
142
  scores = list(emotions.get("scores", []))
@@ -144,10 +135,7 @@ def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float]
144
  if not labels or not scores:
145
  return None
146
 
147
- df = pd.DataFrame({
148
- "Emotion": labels,
149
- "Probability": scores,
150
- })
151
  fig, ax = plt.subplots(figsize=(8, 5))
152
  colors = sns.color_palette("Set2", len(labels))
153
  bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
@@ -169,27 +157,23 @@ def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float]
169
  plt.tight_layout()
170
  return fig
171
 
 
172
  def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
173
- """
174
- Format topic prediction output.
175
-
176
- Args:
177
- topic: Dict with 'label' and 'score' keys
178
- """
179
  if isinstance(topic, TopicPrediction):
180
  label = topic.label
181
  score = topic.confidence
182
  else:
183
  label = str(topic.get("label", "Unknown"))
184
  score = float(topic.get("score", 0.0))
185
- output = f"""
 
186
  ### Predicted Topic
187
-
188
  **{label}**
189
-
190
  Confidence: {score:.2%}
191
  """
192
- return output
193
 
194
  def _clean_tokens(tokens: Iterable[str]) -> list[str]:
195
  cleaned: list[str] = []
 
6
  import sys
7
  from pathlib import Path
8
  from typing import Iterable, Sequence
9
+
10
  import gradio as gr
11
  from gradio.themes import Soft
12
  import matplotlib.pyplot as plt
 
71
  return int(ratio * max_model_length)
72
 
73
  def predict(text: str, compression: int):
74
+ """Run the full pipeline and prepare Gradio outputs."""
75
+ hidden_download = gr.update(value=None, visible=False)
 
 
 
 
 
 
76
  if not text or not text.strip():
77
  return (
78
  "Please enter some text to analyze.",
79
  None,
80
  "No topic prediction available",
81
  None,
82
+ hidden_download,
83
  )
84
  try:
85
  pipeline = get_pipeline()
86
  max_len = map_compression_to_length(compression)
87
+ logger.info("Generating summary with max length of %s", max_len)
88
+
 
89
  summary = pipeline.summarize([text], max_length=max_len)[0]
90
  emotions = pipeline.predict_emotions([text])[0]
91
  topic = pipeline.predict_topics([text])[0]
92
+
93
  summary_html = format_summary(text, summary)
94
  emotion_plot = create_emotion_plot(emotions)
95
  topic_output = format_topic(topic)
96
  attention_fig = create_attention_heatmap(text, summary, pipeline)
97
+ download_bytes = prepare_download(text, summary, emotions, topic)
98
+ download_update = gr.update(value=download_bytes, visible=True)
99
 
100
+ return summary_html, emotion_plot, topic_output, attention_fig, download_update
101
+
102
+ except Exception as exc: # pragma: no cover - surfaced in UI
103
+ logger.error("Prediction error: %s", exc, exc_info=True)
 
 
 
104
  error_msg = "Prediction failed. Check logs for details."
105
+ return error_msg, None, "Error", None, hidden_download
106
+
107
+
108
+ def format_summary(original: str, summary: str) -> str:
109
+ """Format original and summary text for display."""
110
+ return f"""
111
  <div style="padding: 10px; border-radius: 5px;">
112
  <h3>Original Text</h3>
113
  <p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">
 
119
  </p>
120
  </div>
121
  """
 
122
 
123
+
124
+ def create_emotion_plot(
125
+ emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
126
+ ) -> Figure | None:
127
+ """Create a horizontal bar chart for emotion predictions."""
 
128
  if isinstance(emotions, EmotionPrediction):
129
+ labels = list(emotions.labels)
130
+ scores = list(emotions.scores)
131
  else:
132
  labels = list(emotions.get("labels", []))
133
  scores = list(emotions.get("scores", []))
 
135
  if not labels or not scores:
136
  return None
137
 
138
+ df = pd.DataFrame({"Emotion": labels, "Probability": scores})
 
 
 
139
  fig, ax = plt.subplots(figsize=(8, 5))
140
  colors = sns.color_palette("Set2", len(labels))
141
  bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
 
157
  plt.tight_layout()
158
  return fig
159
 
160
+
161
  def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
162
+ """Format topic prediction output as markdown."""
 
 
 
 
 
163
  if isinstance(topic, TopicPrediction):
164
  label = topic.label
165
  score = topic.confidence
166
  else:
167
  label = str(topic.get("label", "Unknown"))
168
  score = float(topic.get("score", 0.0))
169
+
170
+ return f"""
171
  ### Predicted Topic
172
+
173
  **{label}**
174
+
175
  Confidence: {score:.2%}
176
  """
 
177
 
178
  def _clean_tokens(tokens: Iterable[str]) -> list[str]:
179
  cleaned: list[str] = []