thewh1teagle
Refactor code structure for improved readability and maintainability
d6ee776 unverified
"""
Scatter-plot WER against RTF for each model and save the figure as plot.png.

Setup:
    uv pip install matplotlib
"""
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
# One row per model: (name, WER, RTF, category).
# The category is one of "Ours", "Open" or "Proprietary".
models = [
    ("Piper",      0.11,  0.09, "Ours"),
    ("StyleTTS2",  0.07,  0.50, "Ours"),
    ("HebTTS",     0.24, 25.44, "Open"),
    ("LoTHM",      0.49, 84.75, "Open"),
    ("MMS",        0.20,  0.21, "Open"),
    ("SASPEECH",   0.11,  0.16, "Open"),
    ("Robo-Shaul", 0.08,  1.58, "Open"),
    ("Google",     0.04,  4.08, "Proprietary"),
    ("OpenAI",     0.05,  1.60, "Proprietary"),
]

# Keep only the rows that have both a WER and an RTF value (neither is None).
filtered = [
    (name, wer, rtf, category)
    for name, wer, rtf, category in models
    if wer is not None and rtf is not None
]
# A single axes on a 12 x 8 inch figure.
fig, ax = plt.subplots(figsize=(12, 8))

# Marker fill colour for each model category.
colors = dict(Ours='#e74c3c', Open='#3498db', Proprietary='#f39c12')

# Starts out empty; the only other mention of this name is in the
# commented-out legend code further down.
legend_elements = []
# Plot each model as a scatter point and annotate it with its name.
#
# Hand-tuned label placement per model: name -> (x_factor, y_offset, ha, va).
# The label is drawn at (rtf * x_factor, wer + y_offset) with the given
# horizontal/vertical alignment. The horizontal shift is a multiplier rather
# than an additive offset because the x-axis is put on a log scale below,
# where a constant factor corresponds to a constant visual distance.
label_placement = {
    "HebTTS":     (0.75,  0.0,   'right',  'center'),
    "Google":     (1.2,  -0.008, 'left',   'center'),
    "LoTHM":      (0.85,  0.0,   'right',  'center'),
    "OpenAI":     (1.0,  -0.012, 'center', 'top'),
    "Robo-Shaul": (1.2,   0.008, 'left',   'center'),
    "Piper":      (0.9,  -0.022, 'left',   'top'),
    "StyleTTS2":  (0.8,  -0.018, 'center', 'top'),
    "SASPEECH":   (1.2,   0.008, 'left',   'center'),
    "MMS":        (0.85,  0.015, 'right',  'bottom'),
}
# Used for any model without a hand-tuned entry above.
default_placement = (1.15, 0.0, 'left', 'center')

# Same marker size for all models.
marker_size = 180

for name, wer, rtf, category in filtered:
    # Models in the 'Ours' category get a thicker marker outline, a bold
    # label and an "Ours (...)" prefix.
    is_ours = category == 'Ours'

    # Scatter point: RTF on x, WER on y, coloured by category.
    ax.scatter(rtf, wer, s=marker_size, c=colors[category], edgecolors='black',
               linewidths=2 if is_ours else 1.5, zorder=3, alpha=0.8)

    # Text label, positioned relative to the point (drawn above it: zorder 4).
    x_factor, y_offset, ha, va = label_placement.get(name, default_placement)
    ax.text(rtf * x_factor, wer + y_offset,
            f"Ours ({name})" if is_ours else name,
            fontsize=20 if is_ours else 22, ha=ha, va=va,
            color='black', weight='bold' if is_ours else 'normal', zorder=4)
# Log-scaled x-axis. The scale is set first; tick styling and the custom
# formatter/locator are applied afterwards.
ax.set_xscale('log')

# Tick label sizes: 14 pt for major ticks, 12 pt for minor ticks, both axes.
for tick_kind, tick_fontsize in (('major', 14), ('minor', 12)):
    ax.tick_params(axis='both', which=tick_kind, labelsize=tick_fontsize)

# Major x tick labels as plain numbers: no scientific notation, no offset.
plain_formatter = ticker.ScalarFormatter()
plain_formatter.set_scientific(False)
plain_formatter.set_useOffset(False)
ax.xaxis.set_major_formatter(plain_formatter)

# Minor x ticks at 0.2, 0.3, ..., 0.9 times each power of ten.
minor_subs = np.arange(2, 10) * 0.1
ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=minor_subs))

# Large bold axis labels; the figure deliberately has no title.
ax.set_xlabel("← RTF (Faster)", fontsize=28, fontweight='bold')
ax.set_ylabel("← WER (Precise)", fontsize=28, fontweight='bold')

# Faint dashed grid, kept underneath the plotted points and labels.
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
ax.set_axisbelow(True)
# Custom legend (disabled: the code below is commented out, so no legend is drawn)
# from matplotlib.patches import Patch
# legend_elements = [
# plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black',
# linewidths=1.5, label='Our Models'),
# plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black',
# linewidths=1.5, label='Open Source'),
# plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black',
# linewidths=1.5, label='Proprietary')
# ]
# ax.legend(handles=legend_elements, loc='upper right', fontsize=14,
# frameon=True, fancybox=True, shadow=True)
# Make room for the text labels by widening the axis limits.
# The x-axis is log-scaled, so the upper limit is multiplied (by 1.3)
# rather than shifted by an absolute amount.
x_min, x_max = ax.get_xlim()
ax.set_xlim(x_min, x_max * 1.3)

# Pad the y-axis slightly below and above the data.
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min - 0.01, y_max + 0.02)

# Fit the layout only after the limits are final: tight_layout() is a one-off
# adjustment computed from the tick and axis labels present when it is
# called, so running it before changing the limits would fit stale ticks.
plt.tight_layout()

# Save a high-resolution copy on a white background, cropped to the drawn
# content, then open the interactive window.
plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
plt.show()