File size: 4,721 Bytes
d6383e7
 
 
 
 
 
 
 
 
475a735
 
 
 
 
 
 
 
 
d6383e7
 
 
 
 
475a735
 
 
 
 
 
d6383e7
 
 
 
475a735
 
d6383e7
475a735
d6383e7
475a735
d6383e7
 
 
 
 
475a735
 
d6383e7
475a735
d6383e7
475a735
 
d6383e7
475a735
d6383e7
475a735
 
d6383e7
475a735
d6383e7
475a735
 
 
 
d6383e7
475a735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6383e7
475a735
d6383e7
 
475a735
d6383e7
475a735
d6383e7
 
475a735
 
 
d6383e7
 
 
475a735
 
d6383e7
 
 
 
475a735
 
d6383e7
475a735
d6ee776
 
180f090
475a735
 
 
 
 
 
 
 
d6ee776
 
 
 
 
 
 
 
 
 
 
 
d6383e7
 
 
 
475a735
d6383e7
475a735
 
 
 
 
 
 
 
 
 
 
 
d6383e7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
uv pip install matplotlib
"""
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as ticker

# Benchmark table. Each row is (model name, WER, RTF, category); the category
# string is used as the key into `colors` further down.
models = [
    ("Piper", 0.11, 0.09, "Ours"),
    ("StyleTTS2", 0.07, 0.50, "Ours"),
    ("HebTTS", 0.24, 25.44, "Open"),
    ("LoTHM", 0.49, 84.75, "Open"),
    ("MMS", 0.20, 0.21, "Open"),
    ("SASPEECH", 0.11, 0.16, "Open"),
    ("Robo-Shaul", 0.08, 1.58, "Open"),
    ("Google", 0.04, 4.08, "Proprietary"),
    ("OpenAI", 0.05, 1.60, "Proprietary"),
]

# Keep only the rows that have both a WER (index 1) and an RTF (index 2);
# a row with None in either slot cannot be placed on the plot.
filtered = [
    row
    for row in models
    if not (row[1] is None or row[2] is None)
]

# Marker fill colour for each category
colors = dict(Ours='#e74c3c', Open='#3498db', Proprietary='#f39c12')
legend_elements = []

# Single axes on a 12 x 8 inch figure
fig, ax = plt.subplots(figsize=(12, 8))

# Hand-tuned label placement per model, expressed as
# (multiplier applied to RTF, offset added to WER, ha, va).
# The x adjustment is multiplicative and the y adjustment is additive.
label_layout = {
    "HebTTS":     (0.75,  0.0,   'right',  'center'),
    "Google":     (1.2,  -0.008, 'left',   'center'),
    "LoTHM":      (0.85,  0.0,   'right',  'center'),
    "OpenAI":     (1.0,  -0.012, 'center', 'top'),
    "Robo-Shaul": (1.2,   0.008, 'left',   'center'),
    "Piper":      (0.9,  -0.022, 'left',   'top'),
    "StyleTTS2":  (0.8,  -0.018, 'center', 'top'),
    "SASPEECH":   (1.2,   0.008, 'left',   'center'),
    "MMS":        (0.85,  0.015, 'right',  'bottom'),
}
# Placement used for any model that has no entry in the table above.
fallback_layout = (1.15, 0.0, 'left', 'center')

# Draw one marker and one text label per model.
for model, wer_value, rtf_value, group in filtered:
    is_ours = group == 'Ours'

    # Marker: same size for every model; "Ours" gets a thicker outline.
    ax.scatter(rtf_value, wer_value, s=180, c=colors[group],
               edgecolors='black', linewidths=2 if is_ours else 1.5,
               zorder=3, alpha=0.8)

    # Label: "Ours" entries are prefixed and drawn bold at a slightly
    # smaller font size than the other models.
    x_scale, y_shift, ha, va = label_layout.get(model, fallback_layout)
    ax.text(rtf_value * x_scale, wer_value + y_shift,
            f"Ours ({model})" if is_ours else model,
            fontsize=20 if is_ours else 22, ha=ha, va=va,
            color='black', weight='bold' if is_ours else 'normal',
            zorder=4)

# Log-scale the x-axis, but print major tick values as plain decimals
# (no scientific notation, no offset).
ax.set_xscale('log')
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='both', which='minor', labelsize=12)
plain_formatter = ticker.ScalarFormatter()
plain_formatter.set_scientific(False)
plain_formatter.set_useOffset(False)
ax.xaxis.set_major_formatter(plain_formatter)

# Add minor ticks between the powers of ten for better readability
ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1))

# Axis labels with a larger, bold font
ax.set_xlabel("← RTF (Faster)", fontsize=28, fontweight='bold')
ax.set_ylabel("← WER (Precise)", fontsize=28, fontweight='bold')

# Subtle dashed grid, drawn beneath the markers and labels
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
ax.set_axisbelow(True)

# Legend is currently disabled; the code is kept so it can be re-enabled.
# from matplotlib.patches import Patch
# legend_elements = [
#     plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black', 
#                 linewidths=1.5, label='Our Models'),
#     plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black', 
#                 linewidths=1.5, label='Open Source'),
#     plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black', 
#                 linewidths=1.5, label='Proprietary')
# ]

# ax.legend(handles=legend_elements, loc='upper right', fontsize=14, 
#           frameon=True, fancybox=True, shadow=True)

# Extend the upper x-axis limit by 30% to make space for labels
x_min, x_max = ax.get_xlim()
ax.set_xlim(x_min, x_max * 1.3)

# Extend y-axis limits slightly for better spacing
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min - 0.01, y_max + 0.02)

# Fit the layout only after the limits are final: tight_layout() adjusts the
# subplot parameters once, at call time, so it has to see the final tick
# labels rather than the ones that existed before the limits were widened.
plt.tight_layout()

# Save with high quality
plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
plt.show()