thewh1teagle committed on
Update WER plot with improved model metrics and visualization enhancements
Browse files - comparison/wer_plot.py +105 -64
comparison/wer_plot.py
CHANGED
|
@@ -7,110 +7,151 @@ import matplotlib.ticker as ticker
|
|
| 7 |
|
| 8 |
# Data for the models
|
| 9 |
models = [
|
| 10 |
-
("Piper", 0.
|
| 11 |
-
("StyleTTS2", 0.
|
| 12 |
-
("HebTTS", 0.
|
| 13 |
-
("LoTHM", 0.
|
| 14 |
-
("MMS", 0.
|
| 15 |
-
("SASPEECH", 0.
|
| 16 |
-
("Robo-Shaul", 0.
|
| 17 |
-
("Google", 0.
|
| 18 |
-
("OpenAI", 0.
|
| 19 |
]
|
| 20 |
|
| 21 |
# Filter out models with None values for WER or RTF
|
| 22 |
filtered = [m for m in models if m[1] is not None and m[2] is not None]
|
| 23 |
|
| 24 |
-
# Create the figure and axes
|
| 25 |
-
fig, ax = plt.subplots(figsize=(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Plot each model
|
| 28 |
for name, wer, rtf, category in filtered:
|
| 29 |
# Determine color based on category
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
elif category == 'Ours':
|
| 33 |
-
color = 'red'
|
| 34 |
-
else:
|
| 35 |
-
color = 'blue'
|
| 36 |
-
|
| 37 |
# Determine size and weight for our models
|
| 38 |
-
size =
|
| 39 |
weight = 'bold' if category == 'Ours' else 'normal'
|
| 40 |
-
|
| 41 |
|
| 42 |
# Create label for the point
|
| 43 |
label = f"Ours ({name})" if category == 'Ours' else name
|
| 44 |
|
| 45 |
# Plot the scatter point
|
| 46 |
-
ax.scatter(rtf, wer, s=size, c=color, edgecolors='black',
|
|
|
|
| 47 |
|
| 48 |
-
# Adjust text position for
|
| 49 |
if name == "HebTTS":
|
| 50 |
-
x_text = rtf * 0.
|
|
|
|
| 51 |
ha = 'right'
|
|
|
|
| 52 |
elif name == "Google":
|
| 53 |
-
x_text = rtf * 1.
|
| 54 |
-
y_text = wer - 0.
|
| 55 |
ha = 'left'
|
|
|
|
| 56 |
elif name == 'LoTHM':
|
| 57 |
-
x_text = rtf * 0.
|
|
|
|
|
|
|
|
|
|
| 58 |
elif name == "OpenAI":
|
| 59 |
-
x_text = rtf
|
| 60 |
-
y_text = wer - 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
ha = 'right'
|
|
|
|
| 62 |
else:
|
| 63 |
x_text = rtf * 1.15
|
|
|
|
| 64 |
ha = 'left'
|
|
|
|
| 65 |
|
| 66 |
# Add text label for each point
|
| 67 |
-
|
|
|
|
|
|
|
| 68 |
|
| 69 |
# Set x-axis to log scale and format it
|
| 70 |
ax.set_xscale('log')
|
| 71 |
-
ax.tick_params(axis='both', which='major', labelsize=
|
|
|
|
| 72 |
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
|
| 73 |
ax.xaxis.get_major_formatter().set_scientific(False)
|
| 74 |
ax.xaxis.get_major_formatter().set_useOffset(False)
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
ax.
|
| 78 |
-
ax.set_ylabel("WER (lower is more accurate)", fontsize=22)
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
ax.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Adjust layout to prevent labels from being cut off
|
| 84 |
plt.tight_layout()
|
| 85 |
|
| 86 |
-
# Extend x-axis limits by
|
| 87 |
x_min, x_max = ax.get_xlim()
|
| 88 |
-
ax.set_xlim(x_min, x_max * 1.
|
| 89 |
-
|
| 90 |
-
# -
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
# End point (much closer to create shorter arrow with steeper angle)
|
| 102 |
-
arrow_end_x = x_lims[1] * 0.0006 # 60% across the x-axis (shorter horizontal distance)
|
| 103 |
-
arrow_end_y = y_lims[1] * 0.43 # 60% up the y-axis (steeper vertical drop)
|
| 104 |
-
|
| 105 |
-
# Draw arrow pointing from upper-right toward bottom-left
|
| 106 |
-
ax.annotate('',
|
| 107 |
-
xy=(arrow_end_x, arrow_end_y), # End point (arrow head)
|
| 108 |
-
xytext=(arrow_start_x, arrow_start_y), # Start point (arrow tail)
|
| 109 |
-
arrowprops=dict(facecolor='gray', shrink=0.05, width=0.5, headwidth=8, alpha=0.3),
|
| 110 |
-
annotation_clip=False,
|
| 111 |
-
zorder=1) # Behind circles (3) and text (4)
|
| 112 |
-
|
| 113 |
-
# Clear any existing title and save the figure
|
| 114 |
-
plt.title("")
|
| 115 |
-
plt.savefig("plot.png", dpi=1200)
|
| 116 |
plt.show()
|
|
|
|
| 7 |
|
| 8 |
# Data for the models
|
| 9 |
models = [
|
| 10 |
+
("Piper", 0.11, 0.09, "Ours"),
|
| 11 |
+
("StyleTTS2", 0.07, 0.50, "Ours"),
|
| 12 |
+
("HebTTS", 0.24, 25.44, "Open"),
|
| 13 |
+
("LoTHM", 0.49, 84.75, "Open"),
|
| 14 |
+
("MMS", 0.20, 0.21, "Open"),
|
| 15 |
+
("SASPEECH", 0.11, 0.16, "Open"),
|
| 16 |
+
("Robo-Shaul", 0.08, 1.58, "Open"),
|
| 17 |
+
("Google", 0.04, 4.08, "Proprietary"),
|
| 18 |
+
("OpenAI", 0.05, 1.60, "Proprietary"),
|
| 19 |
]
|
| 20 |
|
| 21 |
# Filter out models with None values for WER or RTF
|
| 22 |
filtered = [m for m in models if m[1] is not None and m[2] is not None]
|
| 23 |
|
| 24 |
+
# Create the figure and axes with better sizing
|
| 25 |
+
fig, ax = plt.subplots(figsize=(12, 8))
|
| 26 |
+
|
| 27 |
+
# Color mapping with fancy colors
|
| 28 |
+
colors = {'Ours': '#e74c3c', 'Open': '#3498db', 'Proprietary': '#f39c12'}
|
| 29 |
+
legend_elements = []
|
| 30 |
|
| 31 |
# Plot each model
|
| 32 |
for name, wer, rtf, category in filtered:
|
| 33 |
# Determine color based on category
|
| 34 |
+
color = colors[category]
|
| 35 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Determine size and weight for our models
|
| 37 |
+
size = 180 # Same size for all models
|
| 38 |
weight = 'bold' if category == 'Ours' else 'normal'
|
| 39 |
+
edgewidth = 2 if category == 'Ours' else 1.5
|
| 40 |
|
| 41 |
# Create label for the point
|
| 42 |
label = f"Ours ({name})" if category == 'Ours' else name
|
| 43 |
|
| 44 |
# Plot the scatter point
|
| 45 |
+
scatter = ax.scatter(rtf, wer, s=size, c=color, edgecolors='black',
|
| 46 |
+
linewidths=edgewidth, zorder=3, alpha=0.8)
|
| 47 |
|
| 48 |
+
# Adjust text position for each model
|
| 49 |
if name == "HebTTS":
|
| 50 |
+
x_text = rtf * 0.75
|
| 51 |
+
y_text = wer
|
| 52 |
ha = 'right'
|
| 53 |
+
va = 'center'
|
| 54 |
elif name == "Google":
|
| 55 |
+
x_text = rtf * 1.2
|
| 56 |
+
y_text = wer - 0.008
|
| 57 |
ha = 'left'
|
| 58 |
+
va = 'center'
|
| 59 |
elif name == 'LoTHM':
|
| 60 |
+
x_text = rtf * 0.85
|
| 61 |
+
y_text = wer
|
| 62 |
+
ha = 'right'
|
| 63 |
+
va = 'center'
|
| 64 |
elif name == "OpenAI":
|
| 65 |
+
x_text = rtf
|
| 66 |
+
y_text = wer - 0.012
|
| 67 |
+
ha = 'center'
|
| 68 |
+
va = 'top'
|
| 69 |
+
elif name == "Robo-Shaul":
|
| 70 |
+
x_text = rtf * 1.2
|
| 71 |
+
y_text = wer + 0.008
|
| 72 |
+
ha = 'left'
|
| 73 |
+
va = 'center'
|
| 74 |
+
elif name == "Piper":
|
| 75 |
+
x_text = rtf * 0.9
|
| 76 |
+
y_text = wer - 0.022
|
| 77 |
+
ha = 'left'
|
| 78 |
+
va = 'top'
|
| 79 |
+
elif name == "StyleTTS2":
|
| 80 |
+
x_text = rtf * 0.8
|
| 81 |
+
y_text = wer - 0.018
|
| 82 |
+
ha = 'center'
|
| 83 |
+
va = 'top'
|
| 84 |
+
elif name == "SASPEECH":
|
| 85 |
+
x_text = rtf * 1.2
|
| 86 |
+
y_text = wer + 0.008
|
| 87 |
+
ha = 'left'
|
| 88 |
+
va = 'center'
|
| 89 |
+
elif name == "MMS":
|
| 90 |
+
x_text = rtf * 0.85
|
| 91 |
+
y_text = wer + 0.015
|
| 92 |
ha = 'right'
|
| 93 |
+
va = 'bottom'
|
| 94 |
else:
|
| 95 |
x_text = rtf * 1.15
|
| 96 |
+
y_text = wer
|
| 97 |
ha = 'left'
|
| 98 |
+
va = 'center'
|
| 99 |
|
| 100 |
# Add text label for each point
|
| 101 |
+
fontsize = 20 if category == 'Ours' else 22
|
| 102 |
+
ax.text(x_text, y_text, label, fontsize=fontsize, ha=ha, va=va,
|
| 103 |
+
color='black', weight=weight, zorder=4)
|
| 104 |
|
| 105 |
# Set x-axis to log scale and format it
|
| 106 |
ax.set_xscale('log')
|
| 107 |
+
ax.tick_params(axis='both', which='major', labelsize=14)
|
| 108 |
+
ax.tick_params(axis='both', which='minor', labelsize=12)
|
| 109 |
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
|
| 110 |
ax.xaxis.get_major_formatter().set_scientific(False)
|
| 111 |
ax.xaxis.get_major_formatter().set_useOffset(False)
|
| 112 |
|
| 113 |
+
# Add minor ticks for better readability
|
| 114 |
+
ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1))
|
|
|
|
| 115 |
|
| 116 |
+
# Set axis labels with larger font
|
| 117 |
+
ax.set_xlabel("RTF (lower is faster)", fontsize=18, fontweight='bold')
|
| 118 |
+
ax.set_ylabel("WER (lower is more accurate)", fontsize=18, fontweight='bold')
|
| 119 |
+
|
| 120 |
+
# Remove title
|
| 121 |
+
|
| 122 |
+
# Add subtle grid
|
| 123 |
+
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
|
| 124 |
+
ax.set_axisbelow(True)
|
| 125 |
+
|
| 126 |
+
# Create custom legend
|
| 127 |
+
from matplotlib.patches import Patch
|
| 128 |
+
legend_elements = [
|
| 129 |
+
plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black',
|
| 130 |
+
linewidths=1.5, label='Our Models'),
|
| 131 |
+
plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black',
|
| 132 |
+
linewidths=1.5, label='Open Source'),
|
| 133 |
+
plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black',
|
| 134 |
+
linewidths=1.5, label='Proprietary')
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
ax.legend(handles=legend_elements, loc='upper right', fontsize=14,
|
| 138 |
+
frameon=True, fancybox=True, shadow=True)
|
| 139 |
|
| 140 |
# Adjust layout to prevent labels from being cut off
|
| 141 |
plt.tight_layout()
|
| 142 |
|
| 143 |
+
# Extend x-axis limits by 30% to make space for labels
|
| 144 |
x_min, x_max = ax.get_xlim()
|
| 145 |
+
ax.set_xlim(x_min, x_max * 1.3)
|
| 146 |
+
|
| 147 |
+
# Extend y-axis limits slightly for better spacing
|
| 148 |
+
y_min, y_max = ax.get_ylim()
|
| 149 |
+
ax.set_ylim(y_min - 0.01, y_max + 0.02)
|
| 150 |
+
|
| 151 |
+
# Keep plot clean and simple
|
| 152 |
+
|
| 153 |
+
# Remove figure caption
|
| 154 |
+
|
| 155 |
+
# Save with high quality
|
| 156 |
+
plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
plt.show()
|