# Author: thewh1teagle
# Commit d6ee776 (unverified): "Refactor code structure for improved readability and maintainability"
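"""
Scatter plot of WER against RTF (log-scaled x axis) for the models listed
below, saved to plot.png and then shown interactively.

Install: uv pip install matplotlib   (numpy comes in as a matplotlib dependency)
"""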
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
# Benchmark results, grouped by category. Each row is (name, WER, RTF); the
# grouping key is appended as the fourth field when the table is flattened.
_results_by_category = {
    "Ours": [
        ("Piper", 0.11, 0.09),
        ("StyleTTS2", 0.07, 0.50),
    ],
    "Open": [
        ("HebTTS", 0.24, 25.44),
        ("LoTHM", 0.49, 84.75),
        ("MMS", 0.20, 0.21),
        ("SASPEECH", 0.11, 0.16),
        ("Robo-Shaul", 0.08, 1.58),
    ],
    "Proprietary": [
        ("Google", 0.04, 4.08),
        ("OpenAI", 0.05, 1.60),
    ],
}

# Flat list of (name, wer, rtf, category) tuples, in category order and then
# row order (dicts preserve insertion order).
models = [
    (name, wer, rtf, category)
    for category, rows in _results_by_category.items()
    for name, wer, rtf in rows
]
# Keep only models that have both a WER and an RTF value; a None in either
# field would make the scatter/text calls below fail.
filtered = [m for m in models if m[1] is not None and m[2] is not None]

# Create the figure and axes with better sizing
fig, ax = plt.subplots(figsize=(12, 8))

# Marker fill colour per model category.
colors = {'Ours': '#e74c3c', 'Open': '#3498db', 'Proprietary': '#f39c12'}

# Hand-tuned label placement per model:
#   (x multiplier, y offset, horizontal alignment, vertical alignment)
# The x axis is log-scaled, so horizontal nudges are multiplicative factors of
# the RTF value. The y axis is linear, so vertical nudges are additive offsets
# in WER units. Models not listed here use default_label_placement.
label_placement = {
    "HebTTS": (0.75, 0.0, 'right', 'center'),
    "Google": (1.2, -0.008, 'left', 'center'),
    "LoTHM": (0.85, 0.0, 'right', 'center'),
    "OpenAI": (1.0, -0.012, 'center', 'top'),
    "Robo-Shaul": (1.2, 0.008, 'left', 'center'),
    "Piper": (0.9, -0.022, 'left', 'top'),
    "StyleTTS2": (0.8, -0.018, 'center', 'top'),
    "SASPEECH": (1.2, 0.008, 'left', 'center'),
    "MMS": (0.85, 0.015, 'right', 'bottom'),
}
default_label_placement = (1.15, 0.0, 'left', 'center')

# Same marker area for every model.
size = 180

for name, wer, rtf, category in filtered:
    is_ours = category == 'Ours'

    # Our models get a thicker marker outline and a bold "Ours (<name>)" label.
    edgewidth = 2 if is_ours else 1.5
    weight = 'bold' if is_ours else 'normal'
    label = f"Ours ({name})" if is_ours else name
    # NOTE(review): "Ours" labels use a smaller font than the others,
    # presumably because they are longer. Confirm this is intended.
    fontsize = 20 if is_ours else 22

    # Markers sit above the grid (zorder 3); labels are drawn above markers (zorder 4).
    ax.scatter(rtf, wer, s=size, c=colors[category], edgecolors='black',
               linewidths=edgewidth, zorder=3, alpha=0.8)

    x_scale, y_offset, ha, va = label_placement.get(name, default_label_placement)
    ax.text(rtf * x_scale, wer + y_offset, label, fontsize=fontsize, ha=ha, va=va,
            color='black', weight=weight, zorder=4)
# Log-scale x axis: the RTF values span several orders of magnitude.
ax.set_xscale('log')
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='both', which='minor', labelsize=12)

# Plain decimal tick labels (0.1, 1, 10, ...) instead of the log axis's
# default powers-of-ten notation.
x_formatter = ticker.ScalarFormatter()
x_formatter.set_scientific(False)
x_formatter.set_useOffset(False)
ax.xaxis.set_major_formatter(x_formatter)

# Minor ticks at 0.2-0.9 of each decade (i.e. 2x..9x the next lower power of ten).
ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1))

# Axis labels; the arrows indicate that lower is better on both axes.
ax.set_xlabel("← RTF (Faster)", fontsize=28, fontweight='bold')
ax.set_ylabel("← WER (Precise)", fontsize=28, fontweight='bold')

# Subtle dashed grid drawn beneath the data.
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
ax.set_axisbelow(True)

# Optional category legend (currently disabled; points are labelled directly).
# from matplotlib.patches import Patch
# legend_elements = [
#     plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black',
#                 linewidths=1.5, label='Our Models'),
#     plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black',
#                 linewidths=1.5, label='Open Source'),
#     plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black',
#                 linewidths=1.5, label='Proprietary')
# ]
# ax.legend(handles=legend_elements, loc='upper right', fontsize=14,
#           frameon=True, fancybox=True, shadow=True)

# Widen the view to leave room for the text labels. On the log x axis,
# scaling the upper limit by 1.3 adds 30% to its value (not 30% of the visible
# range). The y padding is in absolute WER units.
x_min, x_max = ax.get_xlim()
ax.set_xlim(x_min, x_max * 1.3)
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min - 0.01, y_max + 0.02)

# Compute the layout only after the final limits are set, so it accounts for
# the final tick labels and nothing is cut off in the interactive window.
plt.tight_layout()

# Save a high-resolution copy with a white background, then show the figure.
plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
plt.show()