Spaces:
Sleeping
Sleeping
vickeee465
commited on
Commit
·
7fb1f06
1
Parent(s):
51db79c
giving a try to crazy imshow thingie
Browse files
app.py
CHANGED
|
@@ -9,6 +9,7 @@ from transformers import AutoTokenizer
|
|
| 9 |
import gradio as gr
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
from matplotlib.colors import LinearSegmentedColormap
|
|
|
|
| 12 |
import plotly.express as px
|
| 13 |
import seaborn as sns
|
| 14 |
|
|
@@ -99,44 +100,31 @@ def prepare_heatmap_data(data):
|
|
| 99 |
heatmap_data.columns = [item["sentence"][:18]+"..." for item in data]
|
| 100 |
return heatmap_data
|
| 101 |
|
| 102 |
-
|
| 103 |
def plot_emotion_heatmap(heatmap_data):
|
| 104 |
-
#
|
| 105 |
-
|
| 106 |
-
# Normalize all values to [0, 1] for each emotion
|
| 107 |
normalized_data = heatmap_data.copy()
|
| 108 |
for row in heatmap_data.index:
|
| 109 |
max_val = heatmap_data.loc[row].max()
|
| 110 |
-
if max_val > 0
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
vmin=0,
|
| 131 |
-
vmax=1,
|
| 132 |
-
ax=ax if i == 0 else ax, # reuse same axis
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
# Format axis
|
| 136 |
-
ax.set_xticks(np.arange(len(heatmap_data.columns)) + 0.5)
|
| 137 |
-
ax.set_xticklabels(heatmap_data.columns, rotation=0, ha='center')
|
| 138 |
-
ax.set_yticks(np.arange(len(heatmap_data.index)) + 0.5)
|
| 139 |
-
ax.set_yticklabels(heatmap_data.index, rotation=0, ha='right')
|
| 140 |
|
| 141 |
ax.set_xlabel("Sentences")
|
| 142 |
ax.set_ylabel("Emotions")
|
|
|
|
| 9 |
import gradio as gr
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
from matplotlib.colors import LinearSegmentedColormap
|
| 12 |
+
import matplotlib.colors as mcolors
|
| 13 |
import plotly.express as px
|
| 14 |
import seaborn as sns
|
| 15 |
|
|
|
|
| 100 |
heatmap_data.columns = [item["sentence"][:18]+"..." for item in data]
|
| 101 |
return heatmap_data
|
| 102 |
|
|
|
|
| 103 |
def plot_emotion_heatmap(heatmap_data):
|
| 104 |
+
# Normalize values to [0, 1] per row (emotion)
|
|
|
|
|
|
|
| 105 |
normalized_data = heatmap_data.copy()
|
| 106 |
for row in heatmap_data.index:
|
| 107 |
max_val = heatmap_data.loc[row].max()
|
| 108 |
+
normalized_data.loc[row] = heatmap_data.loc[row] / max_val if max_val > 0 else 0
|
| 109 |
+
|
| 110 |
+
# Build custom RGB color matrix
|
| 111 |
+
color_matrix = np.empty((len(normalized_data.index), len(normalized_data.columns), 3))
|
| 112 |
+
for i, emotion in enumerate(normalized_data.index):
|
| 113 |
+
base_rgb = mcolors.to_rgb(emotion_colors[emotion])
|
| 114 |
+
for j, val in enumerate(normalized_data.loc[emotion]):
|
| 115 |
+
# Linear interpolation from white to base color
|
| 116 |
+
color = tuple(1 - val * (1 - c) for c in base_rgb)
|
| 117 |
+
color_matrix[i, j] = color
|
| 118 |
+
|
| 119 |
+
fig, ax = plt.subplots(figsize=(len(normalized_data.columns) * 0.5 + 4, len(normalized_data.index) * 0.5 + 2))
|
| 120 |
+
|
| 121 |
+
ax.imshow(color_matrix, aspect='auto')
|
| 122 |
+
|
| 123 |
+
# Ticks and labels
|
| 124 |
+
ax.set_xticks(np.arange(len(normalized_data.columns)))
|
| 125 |
+
ax.set_xticklabels(normalized_data.columns, rotation=0, ha='center')
|
| 126 |
+
ax.set_yticks(np.arange(len(normalized_data.index)))
|
| 127 |
+
ax.set_yticklabels(normalized_data.index, rotation=0, ha='right')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
ax.set_xlabel("Sentences")
|
| 130 |
ax.set_ylabel("Emotions")
|