Alhdrawi committed on
Commit
89ecb88
·
verified ·
1 Parent(s): 9b84499

Upload eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval.py +229 -0
eval.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import numpy as np
3
+ import os
4
+ import pandas as pd
5
+ from PIL import Image
6
+ import h5py
7
+ import matplotlib.pyplot as plt
8
+ from typing import List, Callable
9
+
10
+ import torch
11
+ from torch.utils import data
12
+ from tqdm.notebook import tqdm
13
+ import torch.nn as nn
14
+ from torchvision.transforms import Compose, Normalize, Resize
15
+
16
+ import sklearn
17
+ from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
18
+ from sklearn.metrics import precision_recall_curve, f1_score
19
+ from sklearn.metrics import average_precision_score
20
+ from sklearn.utils import resample
21
+
22
+ import scipy
23
+ import scipy.stats
24
+
25
+ import sys
26
+ sys.path.append('../..')
27
+
28
+ import clip
29
+ from model import CLIP
30
+
31
def compute_mean(stats, is_df=True, spec_labels=None):
    '''
    Average the first-row value of a fixed set of label columns.

    Args:
        * stats - mapping from label name to a column of values. With
          `is_df=True` it must support `stats[list_of_labels].iloc[0]`
          (e.g. a pandas DataFrame); otherwise `stats[label]` only needs to be
          a pandas Series or a plain indexable sequence.
        * is_df (optional) - select the DataFrame code path (default True).
        * spec_labels (optional) - labels to average over. Defaults to the
          five labels that were previously hard-coded in this function.

    Returns:
        * The mean of the first entry of every selected label column.
    '''
    if spec_labels is None:
        spec_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]
    else:
        spec_labels = list(spec_labels)

    if is_df:
        spec_df = stats[spec_labels]
        return np.mean(spec_df.iloc[0])

    # cis is df, within bootstrap
    vals = []
    for spec_label in spec_labels:
        column = stats[spec_label]
        # Take the first row by *position*. `column[0]` on a Series whose index
        # is not integer-based relies on a deprecated pandas fallback, so use
        # `.iloc` whenever it is available and plain indexing otherwise.
        first = column.iloc[0] if hasattr(column, 'iloc') else column[0]
        vals.append(first)
    return np.mean(vals)
41
+
42
def accuracy(output, target, topk=(1,)):
    '''
    Count how many samples have their target class among the top-k predictions.

    Args:
        * output - 2-D tensor of per-class scores, one row per sample.
        * target - tensor holding one class index per sample; any shape with
          one element per sample (e.g. `(batch,)` or `(batch, 1)`) is accepted.
        * topk (optional) - iterable of k values to evaluate.

    Returns:
        * List of floats, one per entry of `topk`: the number (not the
          fraction) of samples whose target is within the top-k predictions.
    '''
    maxk = max(topk)
    # Indices of the `maxk` best classes per sample, transposed to (maxk, batch)
    # so that row r holds every sample's r-th best guess.
    pred = output.topk(maxk, 1, True, True)[1].t()

    # Lay the targets out as (1, batch) and repeat them down the rows so they
    # line up element-wise with `pred`. Expanding a (batch, 1) target to
    # (batch, maxk) instead would compare against the wrong axis.
    expanded_target = target.reshape(1, -1).expand_as(pred)
    correct = pred.eq(expanded_target)

    return [float(correct[:k].sum().item()) for k in topk]
52
+
53
def sigmoid(x):
    '''
    Logistic sigmoid, 1 / (1 + exp(-x)), applied element-wise.

    Uses `scipy.special.expit`, which evaluates the same function without
    overflowing `exp` for large-magnitude negative inputs (the naive formula
    emits a RuntimeWarning there).

    Args:
        * x - scalar or array-like of real values.

    Returns:
        * Value(s) in [0, 1] with the same shape as `x`.
    '''
    # Local import: only `scipy` / `scipy.stats` are imported at file level.
    from scipy.special import expit

    return expit(x)
56
+
57
+ ''' ROC CURVE '''
58
def plot_roc(y_pred, y_true, roc_name, plot=False):
    '''
    Compute the ROC curve and its area for one label, optionally drawing it.

    Args:
        * y_pred - predicted scores, passed to `roc_curve` as the score argument.
        * y_true - ground-truth labels, passed to `roc_curve` as the truth argument.
        * roc_name - title used for the figure when `plot` is True.
        * plot (optional) - when True, show the curve with matplotlib.

    Returns:
        * (fpr, tpr, thresholds, roc_auc) as produced by `roc_curve` and `auc`.
    '''
    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)
    curve = (fpr, tpr, thresholds, roc_auc)

    # Nothing to draw: hand back the numbers straight away.
    if not plot:
        return curve

    plt.figure(dpi=100)
    plt.title(roc_name)
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
    plt.legend(loc='lower right')
    # Diagonal reference line (tpr == fpr).
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()
    return curve
76
+
77
# J = TP/(TP+FN) + TN/(TN+FP) - 1 = tpr - fpr
def choose_operating_point(fpr, tpr, thresholds):
    '''
    Pick the ROC point that maximises Youden's J statistic (tpr - fpr).

    Args:
        * fpr - iterable of false positive rates.
        * tpr - iterable of true positive rates, aligned with `fpr`.
        * thresholds - accepted for interface compatibility; not used.

    Returns:
        * (sensitivity, specificity) of the first point with the largest J.
          If `fpr`/`tpr` are empty, (0, 0) is returned.
    '''
    sens = 0
    spec = 0
    # Start below every attainable J (J >= -1) so the result is always a real
    # point of the curve. Starting at 0 would return the bogus pair (0, 0)
    # whenever no point is strictly better than chance.
    best_j = float('-inf')
    for _fpr, _tpr in zip(fpr, tpr):
        j = _tpr - _fpr
        if j > best_j:
            best_j = j
            sens = _tpr
            spec = 1 - _fpr
    return sens, spec
88
+
89
+ ''' PRECISION-RECALL CURVE '''
90
def plot_pr(y_pred, y_true, pr_name, plot=False):
    '''
    Compute the precision-recall curve for one label, optionally drawing it.

    Args:
        * y_pred - predicted scores, passed to `precision_recall_curve`.
        * y_true - ground-truth labels, passed to `precision_recall_curve`.
        * pr_name - title used for the figure when `plot` is True.
        * plot (optional) - when True, show the curve with matplotlib.

    Returns:
        * (precision, recall, thresholds) as produced by `precision_recall_curve`.
    '''
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

    if plot:
        # The area and the baseline are only needed for the figure.
        pr_auc = auc(recall, precision)
        # Baseline: fraction of entries equal to 1, drawn as a horizontal line.
        # `np.asarray` lets plain sequences work here as well as arrays.
        truth = np.asarray(y_true)
        baseline = np.count_nonzero(truth == 1) / len(truth)

        # Same resolution as the ROC figure; dpi=20 rendered an unreadably
        # small image.
        plt.figure(dpi=100)
        plt.title(pr_name)
        plt.plot(recall, precision, 'b', label='AUC = %0.2f' % pr_auc)
        plt.legend(loc='lower right')
        plt.plot([0, 1], [baseline, baseline], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.show()
    return precision, recall, thresholds
110
+
111
def evaluate(y_pred, y_true, cxr_labels,
             roc_name='Receiver Operating Characteristic', pr_name='Precision-Recall Curve', label_idx_map=None):
    '''
    Compute the AUROC of every predicted class.

    We expect `y_pred` and `y_true` to be 2-D numpy arrays with one row per
    sample; they are indexed as `y_pred[:, i]` and `y_true[:, j]`.

    `y_pred` holds one score column per entry of `cxr_labels`.

    `y_true` holds the binary ground truth. Without `label_idx_map`, column
    `i` of `y_true` is matched with column `i` of `y_pred`; with it,
    `label_idx_map[cxr_labels[i]]` gives the `y_true` column to use.

    Args:
        * y_pred - prediction scores, shape (num_samples, num_classes).
        * y_true - binary ground truth, one row per sample.
        * cxr_labels - label names, one per column of `y_pred`.
        * roc_name, pr_name (optional) - kept for interface compatibility;
          per-label titles are derived from `cxr_labels` instead.
        * label_idx_map (optional) - mapping from label name to the column
          index of that label in `y_true`.

    Returns:
        * Pandas DataFrame with a single row and one `<label>_auc` column per
          class, holding that class's AUROC.
    '''
    import warnings

    num_classes = y_pred.shape[-1]  # number of total labels

    columns = []
    aucs = []
    # Silence warnings only while the metrics are computed. A bare
    # `filterwarnings('ignore')` would stay in force for the whole process.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        for i in range(num_classes):
            cxr_label = cxr_labels[i]
            true_index = i if label_idx_map is None else label_idx_map[cxr_label]

            y_pred_i = y_pred[:, i]           # (num_samples,)
            y_true_i = y_true[:, true_index]  # (num_samples,)

            # Only the AUROC is reported. The operating point and PR curve
            # that used to be computed here were discarded, so they are no
            # longer evaluated (this function runs once per bootstrap draw).
            _, _, _, roc_auc = plot_roc(y_pred_i, y_true_i, cxr_label + ' ROC Curve')

            columns.append(cxr_label + '_auc')
            aucs.append(roc_auc)

    return pd.DataFrame([aucs], columns=columns)
162
+
163
+ ''' Bootstrap and Confidence Intervals '''
164
def compute_cis(data, confidence_level=0.05):
    """
    FUNCTION: compute_cis
    ------------------------------------------------------
    Given a Pandas dataframe of (n, labels), return another
    Pandas dataframe that is (3, labels).

    The rows are, in order, the mean, the lower bound and the upper
    bound of a percentile interval taken from the sorted column values.

    Args:
        * data - Pandas Dataframe, of shape (num_bootstrap_samples, num_labels)
        * confidence_level (optional) - total tail mass left outside the
          interval, split evenly between both ends (0.05 gives the central
          95% interval)

    Returns:
        * Pandas Dataframe, of shape (3, labels), indexed by
          'mean', 'lower', 'upper'; all values rounded to 4 decimals
    """
    tail = confidence_level / 2
    intervals = []
    for column in list(data):
        sorted_perfs = data[column].sort_values()
        n = len(sorted_perfs)
        # Clamp at 0: for small n, int(tail * n) is 0 and the "- 1" would give
        # index -1, i.e. the *largest* value would be reported as lower bound.
        lower_index = max(int(tail * n) - 1, 0)
        upper_index = max(int((1 - tail) * n) - 1, 0)
        lower = sorted_perfs.iloc[lower_index].round(4)
        upper = sorted_perfs.iloc[upper_index].round(4)
        mean = round(sorted_perfs.mean(), 4)
        intervals.append(pd.DataFrame({column: [mean, lower, upper]}))
    intervals_df = pd.concat(intervals, axis=1)
    intervals_df.index = ['mean', 'lower', 'upper']
    return intervals_df
196
+
197
def bootstrap(y_pred, y_true, cxr_labels, n_samples=1000, label_idx_map=None):
    '''
    Bootstrap the per-label AUROC and summarise it with confidence intervals.

    Row indices are drawn with replacement `n_samples` times; each draw is
    scored with `evaluate`, and the stacked scores are passed to `compute_cis`.
    Draw number `i` uses `random_state=i`, so the draws are reproducible.

    Args:
        * y_pred - predictions, indexable by an array of row indices
          (presumably a numpy array with one row per sample - confirm with caller)
        * y_true - ground truth, indexed with the same row indices as `y_pred`
        * cxr_labels - label names forwarded to `evaluate`
        * n_samples (optional) - number of bootstrap draws
        * label_idx_map (optional) - forwarded to `evaluate`

    Returns:
        * (boot_stats, cis) - the concatenated per-draw results of `evaluate`
          and the result of `compute_cis(boot_stats)`

    Note:
        * n_total_labels >= n_cxr_labels
          `n_total_labels` is greater iff alternative labels are being tested
    '''
    np.random.seed(97)

    row_ids = np.arange(len(y_true))

    per_draw = []
    for draw in tqdm(range(n_samples)):
        picked = resample(row_ids, replace=True, random_state=draw)
        draw_stats = evaluate(
            y_pred[picked],
            y_true[picked],
            cxr_labels,
            label_idx_map=label_idx_map,
        )
        per_draw.append(draw_stats)

    # One row per bootstrap draw.
    boot_stats = pd.concat(per_draw)
    return boot_stats, compute_cis(boot_stats)