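# Last change: commit d6ee776 by thewh1teagle (signature unverified):
# "Refactor code structure for improved readability and maintainability"
"""
Plot word error rate (WER) against real-time factor (RTF) for text-to-speech models.

Each model is drawn as one point, coloured by category (our models, open-source,
proprietary). Lower is better on both axes, and the RTF axis is logarithmic.
The figure is saved to plot.png and then shown interactively.

Requirements:
    uv pip install matplotlib
"""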
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from matplotlib.lines import Line2D
# Benchmark results, one tuple per model: (name, WER, RTF, category).
# The category ('Ours', 'Open' or 'Proprietary') decides how the point is styled.
models = [
    ("Piper", 0.11, 0.09, "Ours"),
    ("StyleTTS2", 0.07, 0.50, "Ours"),
    ("HebTTS", 0.24, 25.44, "Open"),
    ("LoTHM", 0.49, 84.75, "Open"),
    ("MMS", 0.20, 0.21, "Open"),
    ("SASPEECH", 0.11, 0.16, "Open"),
    ("Robo-Shaul", 0.08, 1.58, "Open"),
    ("Google", 0.04, 4.08, "Proprietary"),
    ("OpenAI", 0.05, 1.60, "Proprietary"),
]

# A model can only be placed on the chart if both metrics are known, so any
# entry whose WER or RTF is None is left out.
filtered = [
    entry
    for entry in models
    if all(metric is not None for metric in entry[1:3])
]
# Create the figure and a single set of axes. The 12x8 inch canvas leaves room
# for the large axis labels and the per-model text annotations.
fig, ax = plt.subplots(figsize=(12, 8))

# Marker fill colour for each model category used in `models`.
colors = {'Ours': '#e74c3c', 'Open': '#3498db', 'Proprietary': '#f39c12'}
# Where each model's text label sits relative to its point, as
# (x factor, y offset, horizontal alignment, vertical alignment).
# The RTF (x) axis is logarithmic, so labels are nudged multiplicatively along x.
# The WER (y) axis is linear, so they are nudged additively along y.
label_placement = {
    "HebTTS": (0.75, 0.0, 'right', 'center'),
    "Google": (1.2, -0.008, 'left', 'center'),
    "LoTHM": (0.85, 0.0, 'right', 'center'),
    "OpenAI": (1.0, -0.012, 'center', 'top'),
    "Robo-Shaul": (1.2, 0.008, 'left', 'center'),
    "Piper": (0.9, -0.022, 'left', 'top'),
    "StyleTTS2": (0.8, -0.018, 'center', 'top'),
    "SASPEECH": (1.2, 0.008, 'left', 'center'),
    "MMS": (0.85, 0.015, 'right', 'bottom'),
}
# Fallback for models without a dedicated entry: just to the right of the point.
default_label_placement = (1.15, 0.0, 'left', 'center')

marker_size = 180  # same marker area (points^2) for every model

# Draw one point plus one text label per model.
for name, wer, rtf, category in filtered:
    is_ours = category == 'Ours'

    # Our models get a thicker outline, bold text and an "Ours (...)" prefix.
    color = colors[category]
    edgewidth = 2 if is_ours else 1.5
    weight = 'bold' if is_ours else 'normal'
    fontsize = 20 if is_ours else 22
    label = f"Ours ({name})" if is_ours else name

    # zorder 3 keeps points above the grid; labels (zorder 4) sit above points.
    ax.scatter(rtf, wer, s=marker_size, c=color, edgecolors='black',
               linewidths=edgewidth, zorder=3, alpha=0.8)

    x_factor, y_offset, ha, va = label_placement.get(name, default_label_placement)
    ax.text(rtf * x_factor, wer + y_offset, label, fontsize=fontsize,
            ha=ha, va=va, color='black', weight=weight, zorder=4)
# RTF values span several orders of magnitude, so use a log x-axis. Tick labels
# are plain decimals rather than powers of ten.
ax.set_xscale('log')
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='both', which='minor', labelsize=12)

# set_xscale installs its own formatter, so this must come after it.
x_formatter = ticker.ScalarFormatter(useOffset=False)
x_formatter.set_scientific(False)
ax.xaxis.set_major_formatter(x_formatter)

# Unlabelled minor ticks at 2..9 times each power of ten.
ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(2, 10)))

# Arrows point toward the "better" end of each axis (lower is better).
ax.set_xlabel("← RTF (Faster)", fontsize=28, fontweight='bold')
ax.set_ylabel("← WER (Precise)", fontsize=28, fontweight='bold')

# Subtle dashed grid drawn underneath the data.
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
ax.set_axisbelow(True)

# Category legend. Off by default to keep the figure uncluttered; set to True
# to show it.
show_legend = False
if show_legend:
    # Proxy artists: Line2D markers with no data, so the legend adds nothing to
    # the plotted content. scatter's `s` is an area (pt^2), while Line2D's
    # markersize is a diameter (pt), hence the square root.
    legend_handles = [
        Line2D([], [], marker='o', linestyle='', markersize=180 ** 0.5,
               markerfacecolor=colors[category], markeredgecolor='black',
               markeredgewidth=1.5, label=text)
        for category, text in (('Ours', 'Our Models'),
                               ('Open', 'Open Source'),
                               ('Proprietary', 'Proprietary'))
    ]
    ax.legend(handles=legend_handles, loc='upper right', fontsize=14,
              frameon=True, fancybox=True, shadow=True)
# Widen the view to leave room for the text labels. On the log RTF axis,
# multiplying the upper limit by 1.3 adds about 0.11 of a decade on the right.
x_min, x_max = ax.get_xlim()
ax.set_xlim(x_min, x_max * 1.3)

# The WER axis is linear: add a little fixed padding below and above.
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min - 0.01, y_max + 0.02)

# Lay out only after the final limits are set, so tick labels are measured as
# they will actually be drawn and nothing gets cut off.
plt.tight_layout()

# Save a high-resolution copy with an opaque white background, then display it.
plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
plt.show()