Demo_ROC_curves / app.py
jiehou's picture
Update app.py
691060a verified
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from matplotlib.colors import ListedColormap
import numpy as np
import gradio as gr
# Gradio input/output components used by the Interface defined further down.
# Input table: visualize_ROC reads column 1 as the predicted probability and
# column 2 as the true label.
set_input = gr.Dataframe(type="numpy", row_count=10, col_count=3, headers=['Sample Index', 'Predicted Prob', 'Label (Y)'], datatype=["number", "number", "number"])
# The label advertises a default of 0.5, so the initial slider value must match it.
set_input2 = gr.Slider(0, 1, step=0.1, value=0.5, label="Set Probability Threshold (Default = 0.5)")
# NOTE: set_output1 is defined but not wired into the Interface outputs.
set_output1 = gr.Dataframe(type="pandas", label='Predicted Labels')
set_output2 = gr.Image(label="Confusion Matrix")
set_output3 = gr.Image(label="ROC curve")
set_output4 = gr.Image(label="Threshold Tuning curve")
def perf_measure(y_actual, y_hat):
    """Count binary-classification outcomes.

    Compares true and predicted labels position by position. A prediction
    that is neither 0 nor 1 is not counted in any bucket.

    Args:
        y_actual: Sequence of true labels (0 or 1).
        y_hat: Sequence of predicted labels (0 or 1), aligned with y_actual.

    Returns:
        tuple: (TP, FP, TN, FN) as ints.
    """
    TP = FP = TN = FN = 0
    # zip pairs elements by position. Indexing a pandas Series with s[i]
    # would be a *label* lookup and fail for non-default indexes.
    for actual, predicted in zip(y_actual, y_hat):
        if predicted == 1:
            if actual == 1:
                TP += 1
            else:
                FP += 1
        elif predicted == 0:
            if actual == 0:
                TN += 1
            else:
                FN += 1
    return TP, FP, TN, FN
def _binary_metrics(prob, actual_label, threshold):
    """Classify at ``threshold`` and summarise the outcome.

    A sample is predicted positive when ``prob >= threshold``. Any ratio whose
    denominator is zero is reported as 0.

    Returns:
        tuple: (pred_label, TPR, FPR, f1, g_mean), where ``pred_label`` is the
        0/1 int array of predictions and TPR is recall.
    """
    pred_label = (prob >= threshold).astype(int)
    TP, FP, TN, FN = perf_measure(actual_label, pred_label)

    def ratio(num, den):
        # Division with a 0 fallback instead of try/except around each metric.
        return num / den if den else 0

    recall = ratio(TP, TP + FN)          # sensitivity / hit rate / TPR
    precision = ratio(TP, TP + FP)
    specificity = ratio(TN, TN + FP)
    fpr = ratio(FP, FP + TN)             # fall-out = 1 - specificity
    f1 = ratio(2 * recall * precision, precision + recall)
    g_mean = np.sqrt(recall * specificity)
    return pred_label, recall, fpr, f1, g_mean


def _style_white_axes(ax):
    """Give ``ax`` a white background with black ticks and black spines."""
    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    for side in ('left', 'bottom', 'top', 'right'):
        ax.spines[side].set_color('black')


def _plot_confusion_matrix(actual_label, pred_label, path):
    """Save an annotated 2x2 confusion-matrix heatmap to ``path``."""
    # labels=[0, 1] always yields a 2x2 matrix, even when a class is absent,
    # so the TN/FP/FN/TP annotations below stay on their cells.
    cm = confusion_matrix(actual_label, pred_label, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(12, 4))
    sn.heatmap(cm, annot=True, annot_kws={"size": 20}, cbar=False,
               square=False, fmt='g', cmap=ListedColormap(['white']),
               linecolor='black', linewidths=1.5, ax=ax)
    # sn.set changes global matplotlib rcParams; it is kept at this point so
    # the figures look the same as before.
    sn.set(font_scale=2)
    plt.xlabel("Predicted Label")
    plt.ylabel("Actual Label")
    plt.text(0.6, 0.55, '(TN)')
    plt.text(1.6, 0.55, '(FP)')
    plt.text(0.6, 1.55, '(FN)')
    plt.text(1.6, 1.55, '(TP)')
    ax.xaxis.tick_top()
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
    plt.tight_layout()
    plt.savefig(path, dpi=100)
    plt.close(fig)  # release the figure; otherwise every request leaks one


def _plot_roc(actual_label, prob, cur_fpr, cur_tpr, path):
    """Save the ROC curve, marking the current operating point, to ``path``."""
    from sklearn.metrics import roc_curve
    fpr_curve, tpr_curve, _ = roc_curve(actual_label, prob)
    fig, ax = plt.subplots(figsize=(12, 8))
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'
    diag = np.linspace(0, 1, 500)
    plt.plot(fpr_curve, tpr_curve, label='ROC', c='blue', linestyle='-')
    plt.plot(diag, diag, 'black', linestyle='--')  # chance line y = x
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.axvline(x=cur_fpr, color='gray', linestyle='--')
    plt.axhline(y=cur_tpr, color='gray', linestyle='--')
    plt.scatter(cur_fpr, cur_tpr, color='red', s=300)
    _style_white_axes(ax)
    plt.xlabel('False Positive Rate (1 - specificity)')
    plt.ylabel('True Positive Rate (Recall)')
    plt.text(cur_fpr, cur_tpr, 'FPR:%s, TPR:%s' % (round(cur_fpr, 2), round(cur_tpr, 2)))
    plt.title("ROC curve", fontsize=20)
    plt.tight_layout()
    plt.savefig(path, dpi=100)
    plt.close(fig)


def _plot_threshold_tuning(prob, actual_label, set_threshold, f1_cur, g_mean_cur, path):
    """Sweep thresholds 0.00-0.99, plot F1 and G-mean, and save to ``path``.

    Marks the user-selected threshold and the thresholds that maximise each
    curve.
    """
    thres_list = np.arange(0, 1, 0.01)
    f1_score_list = []
    g_mean_list = []
    for thres in thres_list:
        _, _, _, f1, g_mean = _binary_metrics(prob, actual_label, thres)
        f1_score_list.append(f1)
        g_mean_list.append(g_mean)

    best_f1_idx = int(np.argmax(f1_score_list))
    best_gmean_idx = int(np.argmax(g_mean_list))
    best_f1_threshold = thres_list[best_f1_idx]
    best_gmean_threshold = thres_list[best_gmean_idx]
    best_f1_value = f1_score_list[best_f1_idx]
    best_gmean_value = g_mean_list[best_gmean_idx]

    fig, ax = plt.subplots(figsize=(12, 8))
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'
    plt.plot(thres_list, f1_score_list, label='F1-score', c='black', linestyle='-')
    plt.plot(thres_list, g_mean_list, label='G-mean', c='red', linestyle='-')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Current (user-selected) threshold
    plt.axvline(x=set_threshold, color='blue', linestyle=':', linewidth=2, alpha=0.5, label='Current threshold')
    plt.scatter(set_threshold, f1_cur, color='blue', s=200, alpha=0.5, marker='o')
    plt.scatter(set_threshold, g_mean_cur, color='blue', s=200, alpha=0.5, marker='o')
    # Best thresholds over the sweep
    plt.scatter(best_f1_threshold, best_f1_value, color='black', s=400, marker='*',
                edgecolors='gold', linewidths=2, zorder=5, label=f'Best F1 (threshold={best_f1_threshold:.2f})')
    plt.scatter(best_gmean_threshold, best_gmean_value, color='red', s=400, marker='*',
                edgecolors='gold', linewidths=2, zorder=5, label=f'Best G-mean (threshold={best_gmean_threshold:.2f})')
    _style_white_axes(ax)
    plt.xlabel('Threshold cut-off')
    plt.ylabel('F1-score & G-mean')
    plt.legend(loc='upper right', fontsize=10)
    plt.text(best_f1_threshold, best_f1_value + 0.03, f'Best F1: {best_f1_value:.2f}',
             ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    plt.text(best_gmean_threshold, best_gmean_value + 0.03, f'Best G-mean: {best_gmean_value:.2f}',
             ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
    plt.text(set_threshold, f1_cur - 0.05, f'Current F1: {f1_cur:.2f}',
             ha='center', fontsize=9, color='blue', alpha=0.7)
    plt.text(set_threshold, g_mean_cur - 0.05, f'Current G-mean: {g_mean_cur:.2f}',
             ha='center', fontsize=9, color='blue', alpha=0.7)
    plt.title("Threshold tuning curves (F1-score & G-mean)\nGold stars mark optimal thresholds", fontsize=20)
    plt.tight_layout()
    plt.savefig(path, dpi=100)
    plt.close(fig)


def visualize_ROC(set_threshold, set_input):
    """Render the confusion matrix, ROC curve and threshold-tuning curves.

    Args:
        set_threshold: Probability cut-off. A sample is predicted positive when
            its probability is >= this value.
        set_input: 2-D table. Column 1 holds predicted probabilities and
            column 2 holds true 0/1 labels. Column 0 (sample index) is ignored.

    Returns:
        tuple[str, str, str]: Paths of the saved PNGs, in order: confusion
        matrix, ROC curve, threshold-tuning curves.
    """
    # Coerce to float in case the UI table arrives with an object dtype.
    # Labels are cast to int so they match confusion_matrix's labels=[0, 1].
    data = np.asarray(set_input, dtype=float)
    prob = data[:, 1]
    actual_label = data[:, 2].astype(int)

    pred_label, tpr, fpr, f1_cur, g_mean_cur = _binary_metrics(prob, actual_label, set_threshold)

    # NOTE(review): fixed filenames are shared by concurrent requests;
    # consider per-request temp files if the Space sees parallel users.
    _plot_confusion_matrix(actual_label, pred_label, 'tmp.png')
    _plot_roc(actual_label, prob, fpr, tpr, 'tmp2.png')
    _plot_threshold_tuning(prob, actual_label, set_threshold, f1_cur, g_mean_cur, 'tmp3.png')
    return 'tmp.png', 'tmp2.png', 'tmp3.png'
def get_example(N=100, seed=42):
    """Build a reproducible synthetic dataset for the demo examples.

    One quarter of the rows (``N // 4``) are positives (label 1) with
    predicted probabilities drawn uniformly between 0.4 and 0.8. The rest are
    negatives (label 0) with probabilities drawn uniformly between 0 and 0.7.
    Rows are shuffled and then renumbered 1..N.

    Args:
        N: Total number of rows. Defaults to 100.
        seed: Seed for a private ``RandomState``, so the global NumPy RNG is
            left untouched. With the default of 42 the output is the same as
            seeding the global RNG with 42.

    Returns:
        numpy.ndarray of shape (N, 3) with columns Sample Index,
        Predicted Prob and Label (Y).
    """
    rng = np.random.RandomState(seed)
    n_pos = N // 4
    n_neg = N - n_pos  # exact complement, so the row count is always N
    pd_class1 = pd.DataFrame({'Sample Index': range(1, n_pos + 1),
                              'Predicted Prob': rng.uniform(0.4, 0.8, n_pos),
                              'Label (Y)': np.repeat(1, n_pos)})
    pd_class2 = pd.DataFrame({'Sample Index': range(n_pos + 1, N + 1),
                              'Predicted Prob': rng.uniform(0, 0.7, n_neg),
                              'Label (Y)': np.repeat(0, n_neg)})
    pd_all = pd.concat([pd_class1, pd_class2]).reset_index(drop=True)
    # Shuffle with the same private RNG, so the result is deterministic.
    pd_all = pd_all.sample(frac=1, random_state=rng).reset_index(drop=True)
    pd_all['Sample Index'] = range(1, N + 1)
    return pd_all.to_numpy()
### configure Gradio
# Inputs: threshold slider first, then the data table, in the argument order
# visualize_ROC expects. Outputs: the three image paths visualize_ROC returns.
interface = gr.Interface(
    fn=visualize_ROC,
    inputs=[set_input2, set_input],
    outputs=[set_output2, set_output3, set_output4],
    examples_per_page=2,
    examples=[
        [0.5, get_example()],
        [0.7, get_example()],
    ],
    title="ML Demo for Receiver Operating Characteristic (ROC) curve",
    description="Click examples below for a quick demo. Gold stars show optimal F1 and G-mean thresholds.",
    theme='huggingface',
)

# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server as a side effect.
if __name__ == "__main__":
    interface.launch(debug=True, height=1400, width=2800)