jiehou commited on
Commit
84df785
·
verified ·
1 Parent(s): fb65f7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -231,6 +231,15 @@ def visualize_ROC(set_threshold,set_input):
231
  f1_score_list.append(f1_score)
232
  g_mean_list.append(g_mean)
233
 
 
 
 
 
 
 
 
 
 
234
  fig, ax = plt.subplots(figsize=(12,8))
235
 
236
  import matplotlib.pyplot as plt
@@ -240,17 +249,23 @@ def visualize_ROC(set_threshold,set_input):
240
  m1, c1 = 1, 0
241
  x = np.linspace(0, 1, 500)
242
 
 
243
  plt.plot(thres_list, f1_score_list, label = 'F1-score', c='black', linestyle='-')
244
  plt.plot(thres_list, g_mean_list, label = 'G-mean', c='red', linestyle='-')
245
 
246
  plt.xlim(0, 1)
247
  plt.ylim(0, 1)
248
- #xi = (c1 - c2) / (m2 - m1)
249
- #yi = m1 * xi + c1
250
- plt.axvline(x=set_threshold, color='gray', linestyle='--')
251
- plt.axhline(y=f1_score_cur, color='gray', linestyle='--')
252
- plt.scatter(set_threshold, f1_score_cur, color='red', s=300)
253
- plt.scatter(set_threshold, g_mean_cur, color='red', s=300)
 
 
 
 
 
254
 
255
  ax.set_facecolor("white")
256
 
@@ -259,13 +274,24 @@ def visualize_ROC(set_threshold,set_input):
259
  ax.spines['left'].set_color('black')
260
  ax.spines['bottom'].set_color('black')
261
  ax.spines['top'].set_color('black')
262
- ax.spines['right'].set_color('black')
263
  plt.xlabel('Threshold cut-off')
264
  plt.ylabel('F1-score & G-mean')
265
- plt.legend(loc='upper left')
266
- plt.text(set_threshold, f1_score_cur, 'F1-score:%s' % (round(f1_score_cur,2)))
267
- plt.text(set_threshold, g_mean_cur, 'G-mean:%s' % (round(g_mean_cur,2)))
268
- plt.title("Threshold tuning curves (F1-score & G-mean)", fontsize=20)
 
 
 
 
 
 
 
 
 
 
 
269
  plt.tight_layout()
270
 
271
  plt.savefig('tmp3.png', dpi=100)
@@ -301,10 +327,10 @@ interface = gr.Interface(fn=visualize_ROC,
301
  [0.7,get_example()],
302
  ],
303
  title="ML Demo for Receiver Operating Characteristic (ROC) curve",
304
- description= "Click examples below for a quick demo",
305
  theme = 'huggingface',
306
  #layout = 'horizontal',
307
  )
308
 
309
 
310
- interface.launch(debug=True, height=1400, width=2800)
 
231
  f1_score_list.append(f1_score)
232
  g_mean_list.append(g_mean)
233
 
234
+ # Find best thresholds
235
+ best_f1_idx = np.argmax(f1_score_list)
236
+ best_gmean_idx = np.argmax(g_mean_list)
237
+
238
+ best_f1_threshold = thres_list[best_f1_idx]
239
+ best_gmean_threshold = thres_list[best_gmean_idx]
240
+ best_f1_value = f1_score_list[best_f1_idx]
241
+ best_gmean_value = g_mean_list[best_gmean_idx]
242
+
243
  fig, ax = plt.subplots(figsize=(12,8))
244
 
245
  import matplotlib.pyplot as plt
 
249
  m1, c1 = 1, 0
250
  x = np.linspace(0, 1, 500)
251
 
252
+ # Plot curves
253
  plt.plot(thres_list, f1_score_list, label = 'F1-score', c='black', linestyle='-')
254
  plt.plot(thres_list, g_mean_list, label = 'G-mean', c='red', linestyle='-')
255
 
256
  plt.xlim(0, 1)
257
  plt.ylim(0, 1)
258
+
259
+ # Mark current threshold (user selected)
260
+ plt.axvline(x=set_threshold, color='blue', linestyle=':', linewidth=2, alpha=0.5, label='Current threshold')
261
+ plt.scatter(set_threshold, f1_score_cur, color='blue', s=200, alpha=0.5, marker='o')
262
+ plt.scatter(set_threshold, g_mean_cur, color='blue', s=200, alpha=0.5, marker='o')
263
+
264
+ # Mark BEST thresholds (optimal)
265
+ plt.scatter(best_f1_threshold, best_f1_value, color='black', s=400, marker='*',
266
+ edgecolors='gold', linewidths=2, zorder=5, label=f'Best F1 (threshold={best_f1_threshold:.2f})')
267
+ plt.scatter(best_gmean_threshold, best_gmean_value, color='red', s=400, marker='*',
268
+ edgecolors='gold', linewidths=2, zorder=5, label=f'Best G-mean (threshold={best_gmean_threshold:.2f})')
269
 
270
  ax.set_facecolor("white")
271
 
 
274
  ax.spines['left'].set_color('black')
275
  ax.spines['bottom'].set_color('black')
276
  ax.spines['top'].set_color('black')
277
+ ax.spines['right'].set_color='black')
278
  plt.xlabel('Threshold cut-off')
279
  plt.ylabel('F1-score & G-mean')
280
+ plt.legend(loc='upper right', fontsize=10)
281
+
282
+ # Add text annotations for best values
283
+ plt.text(best_f1_threshold, best_f1_value + 0.03, f'Best F1: {best_f1_value:.2f}',
284
+ ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
285
+ plt.text(best_gmean_threshold, best_gmean_value + 0.03, f'Best G-mean: {best_gmean_value:.2f}',
286
+ ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
287
+
288
+ # Add text annotations for current values
289
+ plt.text(set_threshold, f1_score_cur - 0.05, f'Current F1: {f1_score_cur:.2f}',
290
+ ha='center', fontsize=9, color='blue', alpha=0.7)
291
+ plt.text(set_threshold, g_mean_cur - 0.05, f'Current G-mean: {g_mean_cur:.2f}',
292
+ ha='center', fontsize=9, color='blue', alpha=0.7)
293
+
294
+ plt.title("Threshold tuning curves (F1-score & G-mean)\nGold stars mark optimal thresholds", fontsize=20)
295
  plt.tight_layout()
296
 
297
  plt.savefig('tmp3.png', dpi=100)
 
327
  [0.7,get_example()],
328
  ],
329
  title="ML Demo for Receiver Operating Characteristic (ROC) curve",
330
+ description= "Click examples below for a quick demo. Gold stars show optimal F1 and G-mean thresholds.",
331
  theme = 'huggingface',
332
  #layout = 'horizontal',
333
  )
334
 
335
 
336
+ interface.launch(debug=True, height=1400, width=2800)