thewh1teagle committed on
Commit
475a735
·
unverified ·
1 Parent(s): 6c7f537

Update WER plot with improved model metrics and visualization enhancements

Browse files
Files changed (1) hide show
  1. comparison/wer_plot.py +105 -64
comparison/wer_plot.py CHANGED
@@ -7,110 +7,151 @@ import matplotlib.ticker as ticker
7
 
8
  # Data for the models
9
  models = [
10
- ("Piper", 0.22, 0.09, "Ours"),
11
- ("StyleTTS2", 0.18, 0.50, "Ours"),
12
- ("HebTTS", 0.36, 25.44, "Open"),
13
- ("LoTHM", 0.59, 84.75, "Open"),
14
- ("MMS", 0.33, 0.21, "Open"),
15
- ("SASPEECH", 0.22, 0.16, "Open"),
16
- ("Robo-Shaul", 0.20, 1.58, "Open"),
17
- ("Google", 0.20, 4.08, "Proprietary"),
18
- ("OpenAI", 0.21, 1.60, "Proprietary"),
19
  ]
20
 
21
  # Filter out models with None values for WER or RTF
22
  filtered = [m for m in models if m[1] is not None and m[2] is not None]
23
 
24
- # Create the figure and axes
25
- fig, ax = plt.subplots(figsize=(10, 6))
 
 
 
 
26
 
27
  # Plot each model
28
  for name, wer, rtf, category in filtered:
29
  # Determine color based on category
30
- if name in ["Google", "OpenAI"]:
31
- color = '#f4c285'
32
- elif category == 'Ours':
33
- color = 'red'
34
- else:
35
- color = 'blue'
36
-
37
  # Determine size and weight for our models
38
- size = 200 if category == 'Ours' else 100
39
  weight = 'bold' if category == 'Ours' else 'normal'
40
- weight = 'normal' if name in ["Google", "OpenAI"] else weight
41
 
42
  # Create label for the point
43
  label = f"Ours ({name})" if category == 'Ours' else name
44
 
45
  # Plot the scatter point
46
- ax.scatter(rtf, wer, s=size, c=color, edgecolors='black', linewidths=1, zorder=3) # zorder to ensure points are above grid
 
47
 
48
- # Adjust text position for HebTTS to avoid overlap
49
  if name == "HebTTS":
50
- x_text = rtf * 0.85
 
51
  ha = 'right'
 
52
  elif name == "Google":
53
- x_text = rtf * 1.15 # right
54
- y_text = wer - 0.01 # slightly down
55
  ha = 'left'
 
56
  elif name == 'LoTHM':
57
- x_text = rtf * 0.90
 
 
 
58
  elif name == "OpenAI":
59
- x_text = rtf * 0.85 # left
60
- y_text = wer - 0.01 # slightly down
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ha = 'right'
 
62
  else:
63
  x_text = rtf * 1.15
 
64
  ha = 'left'
 
65
 
66
  # Add text label for each point
67
- ax.text(x_text, wer, label, fontsize=22, ha=ha, va='center', color='black', weight=weight, zorder=4)
 
 
68
 
69
  # Set x-axis to log scale and format it
70
  ax.set_xscale('log')
71
- ax.tick_params(axis='both', which='major', labelsize=16)
 
72
  ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
73
  ax.xaxis.get_major_formatter().set_scientific(False)
74
  ax.xaxis.get_major_formatter().set_useOffset(False)
75
 
76
- # Set axis labels
77
- ax.set_xlabel("RTF (lower is faster)", fontsize=22)
78
- ax.set_ylabel("WER (lower is more accurate)", fontsize=22)
79
 
80
- # Remove grid lines
81
- ax.grid(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Adjust layout to prevent labels from being cut off
84
  plt.tight_layout()
85
 
86
- # Extend x-axis limits by 20% to make space for labels/arrows
87
  x_min, x_max = ax.get_xlim()
88
- ax.set_xlim(x_min, x_max * 1.2)
89
-
90
- # --- Add Arrows for Direction of Improvement ---
91
-
92
- # Get current axis limits to position arrows relative to the plot
93
- x_lims = ax.get_xlim()
94
- y_lims = ax.get_ylim()
95
-
96
- # Position arrow in upper-right area of plot, pointing to bottom-left
97
- # Start point (upper-right area)
98
- arrow_start_x = x_lims[1] * 0.002 # 80% across the x-axis
99
- arrow_start_y = y_lims[1] * 0.55 # 85% up the y-axis
100
-
101
- # End point (much closer to create shorter arrow with steeper angle)
102
- arrow_end_x = x_lims[1] * 0.0006 # 60% across the x-axis (shorter horizontal distance)
103
- arrow_end_y = y_lims[1] * 0.43 # 60% up the y-axis (steeper vertical drop)
104
-
105
- # Draw arrow pointing from upper-right toward bottom-left
106
- ax.annotate('',
107
- xy=(arrow_end_x, arrow_end_y), # End point (arrow head)
108
- xytext=(arrow_start_x, arrow_start_y), # Start point (arrow tail)
109
- arrowprops=dict(facecolor='gray', shrink=0.05, width=0.5, headwidth=8, alpha=0.3),
110
- annotation_clip=False,
111
- zorder=1) # Behind circles (3) and text (4)
112
-
113
- # Clear any existing title and save the figure
114
- plt.title("")
115
- plt.savefig("plot.png", dpi=1200)
116
  plt.show()
 
7
 
8
  # Data for the models
9
  models = [
10
+ ("Piper", 0.11, 0.09, "Ours"),
11
+ ("StyleTTS2", 0.07, 0.50, "Ours"),
12
+ ("HebTTS", 0.24, 25.44, "Open"),
13
+ ("LoTHM", 0.49, 84.75, "Open"),
14
+ ("MMS", 0.20, 0.21, "Open"),
15
+ ("SASPEECH", 0.11, 0.16, "Open"),
16
+ ("Robo-Shaul", 0.08, 1.58, "Open"),
17
+ ("Google", 0.04, 4.08, "Proprietary"),
18
+ ("OpenAI", 0.05, 1.60, "Proprietary"),
19
  ]
20
 
21
  # Filter out models with None values for WER or RTF
22
  filtered = [m for m in models if m[1] is not None and m[2] is not None]
23
 
24
+ # Create the figure and axes with better sizing
25
+ fig, ax = plt.subplots(figsize=(12, 8))
26
+
27
+ # Color mapping with fancy colors
28
+ colors = {'Ours': '#e74c3c', 'Open': '#3498db', 'Proprietary': '#f39c12'}
29
+ legend_elements = []
30
 
31
  # Plot each model
32
  for name, wer, rtf, category in filtered:
33
  # Determine color based on category
34
+ color = colors[category]
35
+
 
 
 
 
 
36
  # Determine size and weight for our models
37
+ size = 180 # Same size for all models
38
  weight = 'bold' if category == 'Ours' else 'normal'
39
+ edgewidth = 2 if category == 'Ours' else 1.5
40
 
41
  # Create label for the point
42
  label = f"Ours ({name})" if category == 'Ours' else name
43
 
44
  # Plot the scatter point
45
+ scatter = ax.scatter(rtf, wer, s=size, c=color, edgecolors='black',
46
+ linewidths=edgewidth, zorder=3, alpha=0.8)
47
 
48
+ # Adjust text position for each model
49
  if name == "HebTTS":
50
+ x_text = rtf * 0.75
51
+ y_text = wer
52
  ha = 'right'
53
+ va = 'center'
54
  elif name == "Google":
55
+ x_text = rtf * 1.2
56
+ y_text = wer - 0.008
57
  ha = 'left'
58
+ va = 'center'
59
  elif name == 'LoTHM':
60
+ x_text = rtf * 0.85
61
+ y_text = wer
62
+ ha = 'right'
63
+ va = 'center'
64
  elif name == "OpenAI":
65
+ x_text = rtf
66
+ y_text = wer - 0.012
67
+ ha = 'center'
68
+ va = 'top'
69
+ elif name == "Robo-Shaul":
70
+ x_text = rtf * 1.2
71
+ y_text = wer + 0.008
72
+ ha = 'left'
73
+ va = 'center'
74
+ elif name == "Piper":
75
+ x_text = rtf * 0.9
76
+ y_text = wer - 0.022
77
+ ha = 'left'
78
+ va = 'top'
79
+ elif name == "StyleTTS2":
80
+ x_text = rtf * 0.8
81
+ y_text = wer - 0.018
82
+ ha = 'center'
83
+ va = 'top'
84
+ elif name == "SASPEECH":
85
+ x_text = rtf * 1.2
86
+ y_text = wer + 0.008
87
+ ha = 'left'
88
+ va = 'center'
89
+ elif name == "MMS":
90
+ x_text = rtf * 0.85
91
+ y_text = wer + 0.015
92
  ha = 'right'
93
+ va = 'bottom'
94
  else:
95
  x_text = rtf * 1.15
96
+ y_text = wer
97
  ha = 'left'
98
+ va = 'center'
99
 
100
  # Add text label for each point
101
+ fontsize = 20 if category == 'Ours' else 22
102
+ ax.text(x_text, y_text, label, fontsize=fontsize, ha=ha, va=va,
103
+ color='black', weight=weight, zorder=4)
104
 
105
  # Set x-axis to log scale and format it
106
  ax.set_xscale('log')
107
+ ax.tick_params(axis='both', which='major', labelsize=14)
108
+ ax.tick_params(axis='both', which='minor', labelsize=12)
109
  ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
110
  ax.xaxis.get_major_formatter().set_scientific(False)
111
  ax.xaxis.get_major_formatter().set_useOffset(False)
112
 
113
+ # Add minor ticks for better readability
114
+ ax.xaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1))
 
115
 
116
+ # Set axis labels in bold (font size 18)
117
+ ax.set_xlabel("RTF (lower is faster)", fontsize=18, fontweight='bold')
118
+ ax.set_ylabel("WER (lower is more accurate)", fontsize=18, fontweight='bold')
119
+
120
+ # Remove title
121
+
122
+ # Add subtle grid
123
+ ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
124
+ ax.set_axisbelow(True)
125
+
126
+ # Create custom legend
127
+ from matplotlib.patches import Patch
128
+ legend_elements = [
129
+ plt.scatter([], [], c=colors['Ours'], s=150, edgecolors='black',
130
+ linewidths=1.5, label='Our Models'),
131
+ plt.scatter([], [], c=colors['Open'], s=120, edgecolors='black',
132
+ linewidths=1.5, label='Open Source'),
133
+ plt.scatter([], [], c=colors['Proprietary'], s=120, edgecolors='black',
134
+ linewidths=1.5, label='Proprietary')
135
+ ]
136
+
137
+ ax.legend(handles=legend_elements, loc='upper right', fontsize=14,
138
+ frameon=True, fancybox=True, shadow=True)
139
 
140
  # Adjust layout to prevent labels from being cut off
141
  plt.tight_layout()
142
 
143
+ # Extend x-axis limits by 30% to make space for labels
144
  x_min, x_max = ax.get_xlim()
145
+ ax.set_xlim(x_min, x_max * 1.3)
146
+
147
+ # Extend y-axis limits slightly for better spacing
148
+ y_min, y_max = ax.get_ylim()
149
+ ax.set_ylim(y_min - 0.01, y_max + 0.02)
150
+
151
+ # Keep plot clean and simple
152
+
153
+ # Remove figure caption
154
+
155
+ # Save with high quality
156
+ plt.savefig("plot.png", dpi=300, bbox_inches='tight', facecolor='white')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  plt.show()