Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -231,6 +231,15 @@ def visualize_ROC(set_threshold,set_input):
|
|
| 231 |
f1_score_list.append(f1_score)
|
| 232 |
g_mean_list.append(g_mean)
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
fig, ax = plt.subplots(figsize=(12,8))
|
| 235 |
|
| 236 |
import matplotlib.pyplot as plt
|
|
@@ -240,17 +249,23 @@ def visualize_ROC(set_threshold,set_input):
|
|
| 240 |
m1, c1 = 1, 0
|
| 241 |
x = np.linspace(0, 1, 500)
|
| 242 |
|
|
|
|
| 243 |
plt.plot(thres_list, f1_score_list, label = 'F1-score', c='black', linestyle='-')
|
| 244 |
plt.plot(thres_list, g_mean_list, label = 'G-mean', c='red', linestyle='-')
|
| 245 |
|
| 246 |
plt.xlim(0, 1)
|
| 247 |
plt.ylim(0, 1)
|
| 248 |
-
|
| 249 |
-
#
|
| 250 |
-
plt.axvline(x=set_threshold, color='
|
| 251 |
-
plt.
|
| 252 |
-
plt.scatter(set_threshold,
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
ax.set_facecolor("white")
|
| 256 |
|
|
@@ -259,13 +274,24 @@ def visualize_ROC(set_threshold,set_input):
|
|
| 259 |
ax.spines['left'].set_color('black')
|
| 260 |
ax.spines['bottom'].set_color('black')
|
| 261 |
ax.spines['top'].set_color('black')
|
| 262 |
-
ax.spines['right'].set_color
|
| 263 |
plt.xlabel('Threshold cut-off')
|
| 264 |
plt.ylabel('F1-score & G-mean')
|
| 265 |
-
plt.legend(loc='upper
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
plt.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
plt.tight_layout()
|
| 270 |
|
| 271 |
plt.savefig('tmp3.png', dpi=100)
|
|
@@ -301,10 +327,10 @@ interface = gr.Interface(fn=visualize_ROC,
|
|
| 301 |
[0.7,get_example()],
|
| 302 |
],
|
| 303 |
title="ML Demo for Receiver Operating Characteristic (ROC) curve",
|
| 304 |
-
description= "Click examples below for a quick demo",
|
| 305 |
theme = 'huggingface',
|
| 306 |
#layout = 'horizontal',
|
| 307 |
)
|
| 308 |
|
| 309 |
|
| 310 |
-
interface.launch(debug=True, height=1400, width=2800)
|
|
|
|
| 231 |
f1_score_list.append(f1_score)
|
| 232 |
g_mean_list.append(g_mean)
|
| 233 |
|
| 234 |
+
# Find best thresholds
|
| 235 |
+
best_f1_idx = np.argmax(f1_score_list)
|
| 236 |
+
best_gmean_idx = np.argmax(g_mean_list)
|
| 237 |
+
|
| 238 |
+
best_f1_threshold = thres_list[best_f1_idx]
|
| 239 |
+
best_gmean_threshold = thres_list[best_gmean_idx]
|
| 240 |
+
best_f1_value = f1_score_list[best_f1_idx]
|
| 241 |
+
best_gmean_value = g_mean_list[best_gmean_idx]
|
| 242 |
+
|
| 243 |
fig, ax = plt.subplots(figsize=(12,8))
|
| 244 |
|
| 245 |
import matplotlib.pyplot as plt
|
|
|
|
| 249 |
m1, c1 = 1, 0
|
| 250 |
x = np.linspace(0, 1, 500)
|
| 251 |
|
| 252 |
+
# Plot curves
|
| 253 |
plt.plot(thres_list, f1_score_list, label = 'F1-score', c='black', linestyle='-')
|
| 254 |
plt.plot(thres_list, g_mean_list, label = 'G-mean', c='red', linestyle='-')
|
| 255 |
|
| 256 |
plt.xlim(0, 1)
|
| 257 |
plt.ylim(0, 1)
|
| 258 |
+
|
| 259 |
+
# Mark current threshold (user selected)
|
| 260 |
+
plt.axvline(x=set_threshold, color='blue', linestyle=':', linewidth=2, alpha=0.5, label='Current threshold')
|
| 261 |
+
plt.scatter(set_threshold, f1_score_cur, color='blue', s=200, alpha=0.5, marker='o')
|
| 262 |
+
plt.scatter(set_threshold, g_mean_cur, color='blue', s=200, alpha=0.5, marker='o')
|
| 263 |
+
|
| 264 |
+
# Mark BEST thresholds (optimal)
|
| 265 |
+
plt.scatter(best_f1_threshold, best_f1_value, color='black', s=400, marker='*',
|
| 266 |
+
edgecolors='gold', linewidths=2, zorder=5, label=f'Best F1 (threshold={best_f1_threshold:.2f})')
|
| 267 |
+
plt.scatter(best_gmean_threshold, best_gmean_value, color='red', s=400, marker='*',
|
| 268 |
+
edgecolors='gold', linewidths=2, zorder=5, label=f'Best G-mean (threshold={best_gmean_threshold:.2f})')
|
| 269 |
|
| 270 |
ax.set_facecolor("white")
|
| 271 |
|
|
|
|
| 274 |
ax.spines['left'].set_color('black')
|
| 275 |
ax.spines['bottom'].set_color('black')
|
| 276 |
ax.spines['top'].set_color('black')
|
| 277 |
+
ax.spines['right'].set_color='black')
|
| 278 |
plt.xlabel('Threshold cut-off')
|
| 279 |
plt.ylabel('F1-score & G-mean')
|
| 280 |
+
plt.legend(loc='upper right', fontsize=10)
|
| 281 |
+
|
| 282 |
+
# Add text annotations for best values
|
| 283 |
+
plt.text(best_f1_threshold, best_f1_value + 0.03, f'Best F1: {best_f1_value:.2f}',
|
| 284 |
+
ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
| 285 |
+
plt.text(best_gmean_threshold, best_gmean_value + 0.03, f'Best G-mean: {best_gmean_value:.2f}',
|
| 286 |
+
ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
|
| 287 |
+
|
| 288 |
+
# Add text annotations for current values
|
| 289 |
+
plt.text(set_threshold, f1_score_cur - 0.05, f'Current F1: {f1_score_cur:.2f}',
|
| 290 |
+
ha='center', fontsize=9, color='blue', alpha=0.7)
|
| 291 |
+
plt.text(set_threshold, g_mean_cur - 0.05, f'Current G-mean: {g_mean_cur:.2f}',
|
| 292 |
+
ha='center', fontsize=9, color='blue', alpha=0.7)
|
| 293 |
+
|
| 294 |
+
plt.title("Threshold tuning curves (F1-score & G-mean)\nGold stars mark optimal thresholds", fontsize=20)
|
| 295 |
plt.tight_layout()
|
| 296 |
|
| 297 |
plt.savefig('tmp3.png', dpi=100)
|
|
|
|
| 327 |
[0.7,get_example()],
|
| 328 |
],
|
| 329 |
title="ML Demo for Receiver Operating Characteristic (ROC) curve",
|
| 330 |
+
description= "Click examples below for a quick demo. Gold stars show optimal F1 and G-mean thresholds.",
|
| 331 |
theme = 'huggingface',
|
| 332 |
#layout = 'horizontal',
|
| 333 |
)
|
| 334 |
|
| 335 |
|
| 336 |
+
interface.launch(debug=True, height=1400, width=2800)
|