File size: 4,721 Bytes
d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6383e7 475a735 d6ee776 180f090 475a735 d6ee776 d6383e7 475a735 d6383e7 475a735 d6383e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
"""
Scatter plot of WER against RTF for a set of models, saved to ``plot.png``.

Setup:
    uv pip install matplotlib numpy
"""
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
# Benchmark rows, one per model: (name, WER, RTF, category).
models = [
    ("Piper", 0.11, 0.09, "Ours"),
    ("StyleTTS2", 0.07, 0.50, "Ours"),
    ("HebTTS", 0.24, 25.44, "Open"),
    ("LoTHM", 0.49, 84.75, "Open"),
    ("MMS", 0.20, 0.21, "Open"),
    ("SASPEECH", 0.11, 0.16, "Open"),
    ("Robo-Shaul", 0.08, 1.58, "Open"),
    ("Google", 0.04, 4.08, "Proprietary"),
    ("OpenAI", 0.05, 1.60, "Proprietary"),
]

# Keep only the rows where both metrics are present, i.e. drop any row whose
# WER (index 1) or RTF (index 2) is None.
filtered = [row for row in models if not (row[1] is None or row[2] is None)]
# Marker fill colour for each model category.
colors = {
    'Ours': '#e74c3c',
    'Open': '#3498db',
    'Proprietary': '#f39c12',
}

# Starts empty; in the visible script only the commented-out legend code
# would populate it.
legend_elements = []

# One 12x8-inch figure holding a single axes.
fig, ax = plt.subplots(figsize=(12, 8))
# Hand-tuned label placement per model: name -> (x factor, y offset, ha, va).
# The label is drawn at x = rtf * factor and y = wer + offset, with the given
# horizontal/vertical text alignment. (The x axis is switched to log scale
# further down in the script.)
_LABEL_PLACEMENT = {
    "HebTTS": (0.75, 0.0, 'right', 'center'),
    "Google": (1.2, -0.008, 'left', 'center'),
    "LoTHM": (0.85, 0.0, 'right', 'center'),
    "OpenAI": (1.0, -0.012, 'center', 'top'),
    "Robo-Shaul": (1.2, 0.008, 'left', 'center'),
    "Piper": (0.9, -0.022, 'left', 'top'),
    "StyleTTS2": (0.8, -0.018, 'center', 'top'),
    "SASPEECH": (1.2, 0.008, 'left', 'center'),
    "MMS": (0.85, 0.015, 'right', 'bottom'),
}
# Placement used for any model without an entry above.
_DEFAULT_PLACEMENT = (1.15, 0.0, 'left', 'center')
# Every marker uses the same area, so this is set once outside the loop.
_MARKER_SIZE = 180

# Draw one marker and one text label per model; RTF on x, WER on y.
for name, wer, rtf, category in filtered:
    is_ours = category == 'Ours'
    color = colors[category]

    # Models in the 'Ours' category get a prefixed label, bold text and a
    # thicker marker outline.
    label = f"Ours ({name})" if is_ours else name
    weight = 'bold' if is_ours else 'normal'
    edgewidth = 2 if is_ours else 1.5

    ax.scatter(rtf, wer, s=_MARKER_SIZE, c=color, edgecolors='black',
               linewidths=edgewidth, zorder=3, alpha=0.8)

    x_factor, y_offset, ha, va = _LABEL_PLACEMENT.get(name, _DEFAULT_PLACEMENT)

    # zorder=4 keeps the text above the markers (zorder=3).
    fontsize = 20 if is_ours else 22
    ax.text(rtf * x_factor, wer + y_offset, label, fontsize=fontsize,
            ha=ha, va=va, color='black', weight=weight, zorder=4)
# Switch x to a log scale first: changing the scale installs that scale's
# default tick locators/formatters, so the custom ones are applied afterwards.
ax.set_xscale('log')

# Tick label sizes on both axes: 14 for major ticks, 12 for minor ticks.
for which, labelsize in (('major', 14), ('minor', 12)):
    ax.tick_params(axis='both', which=which, labelsize=labelsize)

# Major x tick labels as plain numbers: no scientific notation, no offset.
plain_fmt = ticker.ScalarFormatter()
ax.xaxis.set_major_formatter(plain_fmt)
plain_fmt.set_scientific(False)
plain_fmt.set_useOffset(False)

# Minor x ticks at multiples 0.2, 0.3, ..., 0.9 of each power of ten.
minor_subs = np.arange(2, 10) * 0.1
ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=minor_subs))

# Large bold axis labels; no title is set.
label_style = dict(fontsize=28, fontweight='bold')
ax.set_xlabel("← RTF (Faster)", **label_style)
ax.set_ylabel("← WER (Precise)", **label_style)

# Faint dashed grid, drawn underneath the plotted points.
ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.3)
ax.set_axisbelow(True)
# Custom legend (disabled): the code below is commented out, so no legend is drawn.
# from matplotlib.patches import Patch
# legend_elements = [
# plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black',
# linewidths=1.5, label='Our Models'),
# plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black',
# linewidths=1.5, label='Open Source'),
# plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black',
# linewidths=1.5, label='Proprietary')
# ]
# ax.legend(handles=legend_elements, loc='upper right', fontsize=14,
# frameon=True, fancybox=True, shadow=True)
# Widen the x range by 30% on the high side to make space for labels.
x_min, x_max = ax.get_xlim()
ax.set_xlim(x_min, x_max * 1.3)

# Pad the y range slightly at both ends for better spacing.
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min - 0.01, y_max + 0.02)

# Fit the layout only once the limits are final: tight_layout() is a one-off
# adjustment based on the figure's contents at the moment it is called, so
# calling it before the limits change would fit a state that no longer exists.
plt.tight_layout()

# Save a 300-dpi PNG on a white background, cropped to the drawn content,
# then open the interactive window.
plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
plt.show()