""" uv pip install matplotlib """ import matplotlib.pyplot as plt import numpy as np import matplotlib.ticker as ticker # Data for the models models = [ ("Piper", 0.11, 0.09, "Ours"), ("StyleTTS2", 0.07, 0.50, "Ours"), ("HebTTS", 0.24, 25.44, "Open"), ("LoTHM", 0.49, 84.75, "Open"), ("MMS", 0.20, 0.21, "Open"), ("SASPEECH", 0.11, 0.16, "Open"), ("Robo-Shaul", 0.08, 1.58, "Open"), ("Google", 0.04, 4.08, "Proprietary"), ("OpenAI", 0.05, 1.60, "Proprietary"), ] # Filter out models with None values for WER or RTF filtered = [m for m in models if m[1] is not None and m[2] is not None] # Create the figure and axes with better sizing fig, ax = plt.subplots(figsize=(12, 8)) # Color mapping with fancy colors colors = {'Ours': '#e74c3c', 'Open': '#3498db', 'Proprietary': '#f39c12'} legend_elements = [] # Plot each model for name, wer, rtf, category in filtered: # Determine color based on category color = colors[category] # Determine size and weight for our models size = 180 # Same size for all models weight = 'bold' if category == 'Ours' else 'normal' edgewidth = 2 if category == 'Ours' else 1.5 # Create label for the point label = f"Ours ({name})" if category == 'Ours' else name # Plot the scatter point scatter = ax.scatter(rtf, wer, s=size, c=color, edgecolors='black', linewidths=edgewidth, zorder=3, alpha=0.8) # Adjust text position for each model if name == "HebTTS": x_text = rtf * 0.75 y_text = wer ha = 'right' va = 'center' elif name == "Google": x_text = rtf * 1.2 y_text = wer - 0.008 ha = 'left' va = 'center' elif name == 'LoTHM': x_text = rtf * 0.85 y_text = wer ha = 'right' va = 'center' elif name == "OpenAI": x_text = rtf y_text = wer - 0.012 ha = 'center' va = 'top' elif name == "Robo-Shaul": x_text = rtf * 1.2 y_text = wer + 0.008 ha = 'left' va = 'center' elif name == "Piper": x_text = rtf * 0.9 y_text = wer - 0.022 ha = 'left' va = 'top' elif name == "StyleTTS2": x_text = rtf * 0.8 y_text = wer - 0.018 ha = 'center' va = 'top' elif name == "SASPEECH": x_text = rtf * 1.2 y_text = wer + 0.008 ha = 'left' va = 'center' elif name == "MMS": x_text = rtf * 0.85 y_text = wer + 0.015 ha = 'right' va = 'bottom' else: x_text = rtf * 1.15 y_text = wer ha = 'left' va = 'center' # Add text label for each point fontsize = 20 if category == 'Ours' else 22 ax.text(x_text, y_text, label, fontsize=fontsize, ha=ha, va=va, color='black', weight=weight, zorder=4) # Set x-axis to log scale and format it ax.set_xscale('log') ax.tick_params(axis='both', which='major', labelsize=14) ax.tick_params(axis='both', which='minor', labelsize=12) ax.xaxis.set_major_formatter(ticker.ScalarFormatter()) ax.xaxis.get_major_formatter().set_scientific(False) ax.xaxis.get_major_formatter().set_useOffset(False) # Add minor ticks for better readability ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1)) # Set axis labels with larger font ax.set_xlabel("← RTF (Faster)", fontsize=28, fontweight='bold') ax.set_ylabel("← WER (Precise)", fontsize=28, fontweight='bold') # Remove title # Add subtle grid ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5) ax.set_axisbelow(True) # Create custom legend # from matplotlib.patches import Patch # legend_elements = [ # plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black', # linewidths=1.5, label='Our Models'), # plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black', # linewidths=1.5, label='Open Source'), # plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black', # linewidths=1.5, label='Proprietary') # ] # ax.legend(handles=legend_elements, loc='upper right', fontsize=14, # frameon=True, fancybox=True, shadow=True) # Adjust layout to prevent labels from being cut off plt.tight_layout() # Extend x-axis limits by 30% to make space for labels x_min, x_max = ax.get_xlim() ax.set_xlim(x_min, x_max * 1.3) # Extend y-axis limits slightly for better spacing y_min, y_max = ax.get_ylim() ax.set_ylim(y_min - 0.01, y_max + 0.02) # Keep plot clean and simple # Remove figure caption # Save with high quality plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white') plt.show()