Spaces:
Sleeping
Sleeping
Commit ·
cd865e2
1
Parent(s): e17e016
Emotion Fix in Gradio
Browse files
- scripts/demo_gradio.py +33 -49
scripts/demo_gradio.py
CHANGED
|
@@ -6,6 +6,7 @@ import json
|
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Iterable, Sequence
|
|
|
|
| 9 |
import gradio as gr
|
| 10 |
from gradio.themes import Soft
|
| 11 |
import matplotlib.pyplot as plt
|
|
@@ -70,51 +71,43 @@ def map_compression_to_length(compression: int, max_model_length: int = 512):
|
|
| 70 |
return int(ratio * max_model_length)
|
| 71 |
|
| 72 |
def predict(text: str, compression: int):
|
| 73 |
-
"""
|
| 74 |
-
|
| 75 |
-
Args:
|
| 76 |
-
text: Text to process
|
| 77 |
-
compression: Compression percentage (20-80)
|
| 78 |
-
Returns:
|
| 79 |
-
Tuple of (summary_html, emotion_plot, topic_output, attention_fig, download_data)
|
| 80 |
-
"""
|
| 81 |
if not text or not text.strip():
|
| 82 |
return (
|
| 83 |
"Please enter some text to analyze.",
|
| 84 |
None,
|
| 85 |
"No topic prediction available",
|
| 86 |
None,
|
| 87 |
-
|
| 88 |
)
|
| 89 |
try:
|
| 90 |
pipeline = get_pipeline()
|
| 91 |
max_len = map_compression_to_length(compression)
|
| 92 |
-
logger.info(
|
| 93 |
-
|
| 94 |
-
# Get the predictions
|
| 95 |
summary = pipeline.summarize([text], max_length=max_len)[0]
|
| 96 |
emotions = pipeline.predict_emotions([text])[0]
|
| 97 |
topic = pipeline.predict_topics([text])[0]
|
| 98 |
-
|
| 99 |
summary_html = format_summary(text, summary)
|
| 100 |
emotion_plot = create_emotion_plot(emotions)
|
| 101 |
topic_output = format_topic(topic)
|
| 102 |
attention_fig = create_attention_heatmap(text, summary, pipeline)
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
return summary_html, emotion_plot, topic_output, attention_fig,
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
except Exception as e:
|
| 111 |
-
logger.error(f"Prediction error: {e}", exc_info=True)
|
| 112 |
error_msg = "Prediction failed. Check logs for details."
|
| 113 |
-
return error_msg, None, "Error", None,
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
| 118 |
<div style="padding: 10px; border-radius: 5px;">
|
| 119 |
<h3>Original Text</h3>
|
| 120 |
<p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">
|
|
@@ -126,17 +119,15 @@ def format_summary(original: str, summary:str) ->str:
|
|
| 126 |
</p>
|
| 127 |
</div>
|
| 128 |
"""
|
| 129 |
-
return html
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
"""
|
| 137 |
if isinstance(emotions, EmotionPrediction):
|
| 138 |
-
labels = emotions.labels
|
| 139 |
-
scores = emotions.scores
|
| 140 |
else:
|
| 141 |
labels = list(emotions.get("labels", []))
|
| 142 |
scores = list(emotions.get("scores", []))
|
|
@@ -144,10 +135,7 @@ def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float]
|
|
| 144 |
if not labels or not scores:
|
| 145 |
return None
|
| 146 |
|
| 147 |
-
df = pd.DataFrame({
|
| 148 |
-
"Emotion": labels,
|
| 149 |
-
"Probability": scores,
|
| 150 |
-
})
|
| 151 |
fig, ax = plt.subplots(figsize=(8, 5))
|
| 152 |
colors = sns.color_palette("Set2", len(labels))
|
| 153 |
bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
|
|
@@ -169,27 +157,23 @@ def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float]
|
|
| 169 |
plt.tight_layout()
|
| 170 |
return fig
|
| 171 |
|
|
|
|
| 172 |
def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
|
| 173 |
-
"""
|
| 174 |
-
Format topic prediction output.
|
| 175 |
-
|
| 176 |
-
Args:
|
| 177 |
-
topic: Dict with 'label' and 'score' keys
|
| 178 |
-
"""
|
| 179 |
if isinstance(topic, TopicPrediction):
|
| 180 |
label = topic.label
|
| 181 |
score = topic.confidence
|
| 182 |
else:
|
| 183 |
label = str(topic.get("label", "Unknown"))
|
| 184 |
score = float(topic.get("score", 0.0))
|
| 185 |
-
|
|
|
|
| 186 |
### Predicted Topic
|
| 187 |
-
|
| 188 |
**{label}**
|
| 189 |
-
|
| 190 |
Confidence: {score:.2%}
|
| 191 |
"""
|
| 192 |
-
return output
|
| 193 |
|
| 194 |
def _clean_tokens(tokens: Iterable[str]) -> list[str]:
|
| 195 |
cleaned: list[str] = []
|
|
|
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Iterable, Sequence
|
| 9 |
+
|
| 10 |
import gradio as gr
|
| 11 |
from gradio.themes import Soft
|
| 12 |
import matplotlib.pyplot as plt
|
|
|
|
| 71 |
return int(ratio * max_model_length)
|
| 72 |
|
| 73 |
def predict(text: str, compression: int):
    """Run the full pipeline and prepare Gradio outputs.

    Args:
        text: Raw input text to summarize and analyze.
        compression: Compression percentage; mapped to a max token length.

    Returns:
        Tuple of (summary_html, emotion_plot, topic_output, attention_fig,
        download_update), where download_update is a ``gr.update`` controlling
        the download widget's value and visibility.
    """
    # Download widget stays hidden whenever we have nothing to offer.
    hidden_download = gr.update(value=None, visible=False)

    # Guard clause: empty or whitespace-only input short-circuits the pipeline.
    if not text or not text.strip():
        return (
            "Please enter some text to analyze.",
            None,
            "No topic prediction available",
            None,
            hidden_download,
        )

    try:
        pipeline = get_pipeline()
        max_len = map_compression_to_length(compression)
        logger.info("Generating summary with max length of %s", max_len)

        # Run the three model heads on the single input text.
        summary = pipeline.summarize([text], max_length=max_len)[0]
        emotions = pipeline.predict_emotions([text])[0]
        topic = pipeline.predict_topics([text])[0]

        # Turn raw predictions into UI-ready artifacts.
        summary_html = format_summary(text, summary)
        emotion_plot = create_emotion_plot(emotions)
        topic_output = format_topic(topic)
        attention_fig = create_attention_heatmap(text, summary, pipeline)

        # Package results for download and reveal the download widget.
        download_bytes = prepare_download(text, summary, emotions, topic)
        download_update = gr.update(value=download_bytes, visible=True)

        return summary_html, emotion_plot, topic_output, attention_fig, download_update

    except Exception as exc:  # pragma: no cover - surfaced in UI
        # Broad catch is deliberate: this is the UI boundary, so any model
        # failure is logged and rendered as a friendly error instead of a crash.
        logger.error("Prediction error: %s", exc, exc_info=True)
        error_msg = "Prediction failed. Check logs for details."
        return error_msg, None, "Error", None, hidden_download
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def format_summary(original: str, summary: str) -> str:
|
| 109 |
+
"""Format original and summary text for display."""
|
| 110 |
+
return f"""
|
| 111 |
<div style="padding: 10px; border-radius: 5px;">
|
| 112 |
<h3>Original Text</h3>
|
| 113 |
<p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">
|
|
|
|
| 119 |
</p>
|
| 120 |
</div>
|
| 121 |
"""
|
|
|
|
| 122 |
|
| 123 |
+
|
| 124 |
+
def create_emotion_plot(
|
| 125 |
+
emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
|
| 126 |
+
) -> Figure | None:
|
| 127 |
+
"""Create a horizontal bar chart for emotion predictions."""
|
|
|
|
| 128 |
if isinstance(emotions, EmotionPrediction):
|
| 129 |
+
labels = list(emotions.labels)
|
| 130 |
+
scores = list(emotions.scores)
|
| 131 |
else:
|
| 132 |
labels = list(emotions.get("labels", []))
|
| 133 |
scores = list(emotions.get("scores", []))
|
|
|
|
| 135 |
if not labels or not scores:
|
| 136 |
return None
|
| 137 |
|
| 138 |
+
df = pd.DataFrame({"Emotion": labels, "Probability": scores})
|
|
|
|
|
|
|
|
|
|
| 139 |
fig, ax = plt.subplots(figsize=(8, 5))
|
| 140 |
colors = sns.color_palette("Set2", len(labels))
|
| 141 |
bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
|
|
|
|
| 157 |
plt.tight_layout()
|
| 158 |
return fig
|
| 159 |
|
| 160 |
+
|
| 161 |
def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
    """Format topic prediction output as markdown.

    Accepts either a ``TopicPrediction`` object or a plain dict with
    'label' and 'score' keys; missing dict keys fall back to defaults.
    """
    if isinstance(topic, TopicPrediction):
        label, score = topic.label, topic.confidence
    else:
        # Dict path: tolerate absent keys rather than raising.
        label = str(topic.get("label", "Unknown"))
        score = float(topic.get("score", 0.0))

    return f"""
### Predicted Topic

**{label}**

Confidence: {score:.2%}
"""
|
|
|
|
| 177 |
|
| 178 |
def _clean_tokens(tokens: Iterable[str]) -> list[str]:
|
| 179 |
cleaned: list[str] = []
|