"""Gradio demo: interactive ROC curve, confusion matrix, and threshold tuning.

The user supplies a table of (sample index, predicted probability, true label)
rows plus a probability threshold; the app renders three PNGs:
confusion matrix (tmp.png), ROC curve (tmp2.png), F1/G-mean threshold-tuning
curves (tmp3.png).
"""

import numpy as np
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.metrics import confusion_matrix, roc_curve
import gradio as gr

# --- Gradio widgets -------------------------------------------------------
set_input = gr.Dataframe(
    type="numpy",
    row_count=10,
    col_count=3,
    headers=['Sample Index', 'Predicted Prob', 'Label (Y)'],
    datatype=["number", "number", "number"],
)
# FIX: the slider previously defaulted to 0.4 while its label claimed 0.5;
# the value is now aligned with the label.
set_input2 = gr.Slider(0, 1, step=0.1, value=0.5,
                       label="Set Probability Threshold (Default = 0.5)")
set_output1 = gr.Dataframe(type="pandas", label='Predicted Labels')
set_output2 = gr.Image(label="Confusion Matrix")
set_output3 = gr.Image(label="ROC curve")
set_output4 = gr.Image(label="Threshold Tuning curve")


def perf_measure(y_actual, y_hat):
    """Count confusion-matrix cells for binary labels.

    Parameters
    ----------
    y_actual, y_hat : sequences of {0, 1}
        Ground-truth and predicted labels, compared positionally.

    Returns
    -------
    tuple of int : (TP, FP, TN, FN)
    """
    TP = FP = TN = FN = 0
    for actual, pred in zip(y_actual, y_hat):
        if pred == 1:
            if actual == 1:
                TP += 1
            else:
                FP += 1
        else:
            if actual == 0:
                TN += 1
            else:
                FN += 1
    return (TP, FP, TN, FN)


def _metrics(TP, FP, TN, FN):
    """Derive classification rates from confusion-matrix counts.

    Replaces the original copy-pasted try/bare-except blocks: any rate whose
    denominator is zero is reported as 0.

    Returns
    -------
    tuple : (recall, precision, specificity, FPR, f1, g_mean)
    """
    def _ratio(num, den):
        # 0 on an empty denominator, mirroring the original except-fallbacks.
        return num / den if den else 0

    recall = _ratio(TP, TP + FN)          # a.k.a. TPR / sensitivity
    precision = _ratio(TP, TP + FP)
    specificity = _ratio(TN, TN + FP)
    FPR = _ratio(FP, FP + TN)             # fall-out
    f1 = _ratio(2 * recall * precision, precision + recall)
    g_mean = np.sqrt(recall * specificity)
    return recall, precision, specificity, FPR, f1, g_mean


def _plot_confusion(df):
    """Render the confusion-matrix heatmap and save it to tmp.png."""
    cm = confusion_matrix(df['Actual Label'], df['Predicted Label'])
    fig, ax = plt.subplots(figsize=(12, 4))
    sn.heatmap(cm, annot=True, annot_kws={"size": 20}, cbar=False,
               square=False, fmt='g', cmap=ListedColormap(['white']),
               linecolor='black', linewidths=1.5)
    sn.set(font_scale=2)
    plt.xlabel("Predicted Label")
    plt.ylabel("Actual Label")
    # Cell tags: sklearn's confusion_matrix puts TN at (0,0) and TP at (1,1).
    plt.text(0.6, 0.55, '(TN)')
    plt.text(1.6, 0.55, '(FP)')
    plt.text(0.6, 1.55, '(FN)')
    plt.text(1.6, 1.55, '(TP)')
    ax.xaxis.tick_top()
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
    plt.tight_layout()
    plt.savefig('tmp.png', dpi=100)


def _plot_roc(df, FPR, TPR):
    """Render the ROC curve with the operating point marked; save to tmp2.png."""
    fpr_mod, tpr_mod, _ = roc_curve(df['Actual Label'], df['Predicted Prob'])
    fig, ax = plt.subplots(figsize=(12, 8))
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'
    diag = np.linspace(0, 1, 500)
    plt.plot(fpr_mod, tpr_mod, label='ROC', c='blue', linestyle='-')
    plt.plot(diag, diag, 'black', linestyle='--')  # chance diagonal y = x
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Cross-hairs on the operating point for the chosen threshold.
    plt.axvline(x=FPR, color='gray', linestyle='--')
    plt.axhline(y=TPR, color='gray', linestyle='--')
    plt.scatter(FPR, TPR, color='red', s=300)
    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    for side in ('left', 'bottom', 'top', 'right'):
        ax.spines[side].set_color('black')
    plt.xlabel('False Positive Rate (1 - specificity)')
    plt.ylabel('True Positive Rate (Recall)')
    plt.text(FPR, TPR, 'FPR:%s, TPR:%s' % (round(FPR, 2), round(TPR, 2)))
    plt.title("ROC curve", fontsize=20)
    plt.tight_layout()
    plt.savefig('tmp2.png', dpi=100)


def _plot_threshold_curves(prob, actual_label, set_threshold, f1_cur, gmean_cur):
    """Sweep thresholds in [0, 1), plot F1-score and G-mean vs. threshold,
    mark the user's current threshold and the optima; save to tmp3.png."""
    thres_list, f1_list, gmean_list = [], [], []
    for thres in np.arange(0, 1, 0.01):
        pred = (prob >= thres).astype(int)
        TP, FP, TN, FN = perf_measure(actual_label, pred)
        *_, f1, gm = _metrics(TP, FP, TN, FN)
        thres_list.append(thres)
        f1_list.append(f1)
        gmean_list.append(gm)

    best_f1_idx = int(np.argmax(f1_list))
    best_gmean_idx = int(np.argmax(gmean_list))
    best_f1_threshold = thres_list[best_f1_idx]
    best_gmean_threshold = thres_list[best_gmean_idx]
    best_f1_value = f1_list[best_f1_idx]
    best_gmean_value = gmean_list[best_gmean_idx]

    fig, ax = plt.subplots(figsize=(12, 8))
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'
    plt.plot(thres_list, f1_list, label='F1-score', c='black', linestyle='-')
    plt.plot(thres_list, gmean_list, label='G-mean', c='red', linestyle='-')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Mark current threshold (user selected).
    plt.axvline(x=set_threshold, color='blue', linestyle=':', linewidth=2,
                alpha=0.5, label='Current threshold')
    plt.scatter(set_threshold, f1_cur, color='blue', s=200, alpha=0.5, marker='o')
    plt.scatter(set_threshold, gmean_cur, color='blue', s=200, alpha=0.5, marker='o')
    # Mark BEST thresholds (optimal) with gold-edged stars.
    plt.scatter(best_f1_threshold, best_f1_value, color='black', s=400,
                marker='*', edgecolors='gold', linewidths=2, zorder=5,
                label=f'Best F1 (threshold={best_f1_threshold:.2f})')
    plt.scatter(best_gmean_threshold, best_gmean_value, color='red', s=400,
                marker='*', edgecolors='gold', linewidths=2, zorder=5,
                label=f'Best G-mean (threshold={best_gmean_threshold:.2f})')
    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    for side in ('left', 'bottom', 'top', 'right'):
        ax.spines[side].set_color('black')
    plt.xlabel('Threshold cut-off')
    plt.ylabel('F1-score & G-mean')
    plt.legend(loc='upper right', fontsize=10)
    # Annotate best values.
    plt.text(best_f1_threshold, best_f1_value + 0.03,
             f'Best F1: {best_f1_value:.2f}', ha='center', fontsize=10,
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    plt.text(best_gmean_threshold, best_gmean_value + 0.03,
             f'Best G-mean: {best_gmean_value:.2f}', ha='center', fontsize=10,
             bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))
    # Annotate current values.
    plt.text(set_threshold, f1_cur - 0.05, f'Current F1: {f1_cur:.2f}',
             ha='center', fontsize=9, color='blue', alpha=0.7)
    plt.text(set_threshold, gmean_cur - 0.05, f'Current G-mean: {gmean_cur:.2f}',
             ha='center', fontsize=9, color='blue', alpha=0.7)
    plt.title("Threshold tuning curves (F1-score & G-mean)\n"
              "Gold stars mark optimal thresholds", fontsize=20)
    plt.tight_layout()
    plt.savefig('tmp3.png', dpi=100)


def visualize_ROC(set_threshold, set_input):
    """Gradio callback: build the three diagnostic plots for one threshold.

    Parameters
    ----------
    set_threshold : float
        Probability cut-off in [0, 1]; probabilities >= cut-off predict 1.
    set_input : numpy.ndarray, shape (n, 3)
        Columns: sample index, predicted probability, actual label (0/1).

    Returns
    -------
    tuple of str
        Paths of the saved images: ('tmp.png', 'tmp2.png', 'tmp3.png').
    """
    prob = set_input[:, 1]
    actual_label = set_input[:, 2]
    pred_label = (prob >= set_threshold).astype(int)
    df = pd.DataFrame({
        'Predicted Prob': prob,
        'Predicted Label': pred_label,
        'Actual Label': actual_label,
    })

    TP, FP, TN, FN = perf_measure(df['Actual Label'], df['Predicted Label'])
    recall, precision, specificity, FPR, f1_cur, gmean_cur = _metrics(TP, FP, TN, FN)
    TPR = recall  # true-positive rate is recall by definition

    _plot_confusion(df)
    _plot_roc(df, FPR, TPR)
    _plot_threshold_curves(prob, actual_label, set_threshold, f1_cur, gmean_cur)
    return 'tmp.png', 'tmp2.png', 'tmp3.png'


def get_example():
    """Generate a reproducible 100-row demo set (25% positives, 75% negatives).

    Positive samples get probabilities in [0.4, 0.8), negatives in [0, 0.7),
    so the classes overlap and the threshold choice matters. Rows are shuffled
    and re-indexed 1..N.

    Returns
    -------
    numpy.ndarray, shape (100, 3)
    """
    np.random.seed(seed=42)
    N = 100
    pd_class1 = pd.DataFrame({
        'Sample Index': list(range(1, N // 4 + 1)),
        'Predicted Prob': np.random.uniform(0.4, 0.8, N // 4),
        'Label (Y)': np.repeat(1, N // 4),
    })
    pd_class2 = pd.DataFrame({
        'Sample Index': list(range(N // 4 + 1, N + 1)),
        'Predicted Prob': np.random.uniform(0, 0.7, 3 * N // 4),
        'Label (Y)': np.repeat(0, 3 * N // 4),
    })
    pd_all = pd.concat([pd_class1, pd_class2]).reset_index(drop=True)
    pd_all = pd_all.sample(frac=1).reset_index(drop=True)
    pd_all['Sample Index'] = list(range(1, N + 1))
    return pd_all.to_numpy()


# --- Gradio interface -----------------------------------------------------
interface = gr.Interface(
    fn=visualize_ROC,
    inputs=[set_input2, set_input],
    outputs=[set_output2, set_output3, set_output4],
    examples_per_page=2,
    examples=[
        [0.5, get_example()],
        [0.7, get_example()],
    ],
    title="ML Demo for Receiver Operating Characteristic (ROC) curve",
    description="Click examples below for a quick demo. Gold stars show optimal F1 and G-mean thresholds.",
    theme='huggingface',
)

# Guarded so importing this module (e.g. for testing) does not start a server.
if __name__ == "__main__":
    interface.launch(debug=True, height=1400, width=2800)