Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import sys, tqdm | |
| import torch | |
| import torch.nn.functional as F | |
| from numpy import interp | |
| from collections.abc import Sequence | |
| from collections import defaultdict | |
| from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, auc | |
| from sklearn.metrics import precision_recall_curve, average_precision_score | |
| from sklearn.metrics import balanced_accuracy_score, precision_score | |
| import warnings | |
| import inspect | |
| _depth = lambda L: isinstance(L, (Sequence, np.ndarray)) and max(map(_depth, L)) + 1 | |
def get_metrics(y_true, y_pred, scores, mask):
    '''Compute binary-classification metrics over the entries where mask == 1.

    Parameters
    ----------
    y_true, y_pred, scores, mask : np.ndarray
        Arrays of equal length. `scores` holds the score of the positive
        class used for the ranking metrics (AUC ROC / AUC PR).
        Only positions with mask == 1 are evaluated.

    Returns
    -------
    dict
        Metric name -> value. If the confusion-matrix-based metrics cannot
        be computed (e.g. only one class present so cnf.ravel() does not
        unpack into 4 cells), every value is the sentinel -1.
    '''
    with warnings.catch_warnings():
        # divisions by zero below yield nan/inf with a RuntimeWarning;
        # silence those warnings rather than crash.
        warnings.simplefilter("ignore")
        keep = np.where(mask == 1)
        masked_y_true = y_true[keep]
        masked_y_pred = y_pred[keep]
        masked_scores = scores[keep]
        try:
            # metrics that are based on hard predictions
            cnf = confusion_matrix(masked_y_true, masked_y_pred)
            TN, FP, FN, TP = cnf.ravel()
            N = TN + TP + FN + FP
            S = (TP + FN) / N  # prevalence of the positive class
            P = (TP + FP) / N  # fraction predicted positive
            acc = (TN + TP) / N
            sen = TP / (TP + FN)
            spc = TN / (TN + FP)
            prc = TP / (TP + FP)
            f1s = 2 * (prc * sen) / (prc + sen)
            mcc = (TP / N - S * P) / np.sqrt(P * S * (1 - S) * (1 - P))
            NPV = TN / (TN + FN)
            # metrics that are based on scores; each may fail independently
            # (e.g. only one class present), in which case it defaults to 0.
            try:
                auc_roc = roc_auc_score(masked_y_true, masked_scores)
            except Exception:
                auc_roc = 0
            try:
                auc_pr = average_precision_score(masked_y_true, masked_scores)
            except Exception:
                auc_pr = 0
            bal_acc = balanced_accuracy_score(masked_y_true, masked_y_pred)
        except Exception:
            # BUG FIX: the original fallback did not assign NPV, so any
            # failure here raised NameError when building the dict below.
            cnf = acc = bal_acc = prc = sen = spc = f1s = mcc = -1
            auc_roc = auc_pr = NPV = -1
    # construct the dictionary of all metrics
    met = {}
    met['Confusion Matrix'] = cnf
    met['Accuracy'] = acc
    met['Balanced Accuracy'] = bal_acc
    met['Precision'] = prc
    met['Sensitivity/Recall'] = sen
    met['Specificity'] = spc
    met['F1 score'] = f1s
    met['MCC'] = mcc
    met['AUC (ROC)'] = auc_roc
    met['AUC (PR)'] = auc_pr
    met['NPV'] = NPV
    return met
def get_metrics_multitask(y_true, y_pred, scores, mask):
    '''Compute per-task metrics.

    When the inputs are dicts keyed by task name, returns a dict of metric
    dicts with the same keys; otherwise the inputs are 2-D arrays with one
    column per task and a list of metric dicts (one per column) is returned.
    '''
    if type(y_true) is dict:
        return {task: get_metrics(y_true[task], y_pred[task],
                                  scores[task], mask[task])
                for task in y_true.keys()}
    n_tasks = len(y_true[0])
    return [get_metrics(y_true[:, col], y_pred[:, col],
                        scores[:, col], mask[:, col])
            for col in range(n_tasks)]
def print_metrics(met):
    '''Print each scalar metric on its own aligned line; the confusion
    matrix entry is skipped because it is not a scalar.'''
    for name, value in met.items():
        if name in ['Confusion Matrix']:
            continue
        print('{}:\t{:.4f}'.format(name, value).expandtabs(20))
def print_metrics_multitask(met):
    '''Print per-task metrics side by side, one metric per line.

    Accepts either a dict of metric dicts (keyed by task name) or a list of
    metric dicts. nan values are rendered as dashes; the confusion matrix is
    skipped because it is not a scalar.
    '''
    if type(met) is dict:
        lbl_ks = list(met.keys())
        for met_k in met[lbl_ks[0]].keys():
            if met_k in ['Confusion Matrix']:
                continue
            vals = [met[lbl_k][met_k] for lbl_k in lbl_ks]
            line = ('{}:\t' + '{:.4f} ' * len(met)).format(met_k, *vals)
            print(line.replace('nan', '------').expandtabs(20))
    else:
        for k in met[0]:
            if k in ['Confusion Matrix']:
                continue
            vals = [task_met[k] for task_met in met]
            line = ('{}:\t' + '{:.4f} ' * len(met)).format(k, *vals)
            print(line.replace('nan', '------').expandtabs(20))
def pr_interp(rc_, rc, pr):
    '''Interpolate a precision-recall curve at the recall grid `rc_`.

    PR curves cannot be interpolated linearly between operating points; this
    evaluates precision along the achievable non-linear path between the two
    bracketing points (presumably following Davis & Goadrich, "The
    Relationship Between Precision-Recall and ROC Curves" — TODO confirm).

    rc_ : target recall values; must be > 0 (the caller sets rc_[0] = 1e-16)
    rc  : achieved recall values, assumed sorted ascending
    pr  : precision at each achieved recall
    '''
    pr_ = np.zeros_like(rc_)
    # for each target recall, index of the first achieved recall >= it
    locs = np.searchsorted(rc, rc_)
    for idx, loc in enumerate(locs):
        l = loc - 1
        r = loc
        # bracketing points; outside the measured range fall back to the
        # curve endpoints (recall 0, precision 1) and (recall 1, precision 0)
        r1 = rc[l] if l > -1 else 0
        r2 = rc[r] if r < len(rc) else 1
        p1 = pr[l] if l > -1 else 1
        p2 = pr[r] if r < len(rc) else 0
        # intermediate terms of the hyperbolic segment; each guard floors a
        # near-zero denominator at 1e-16 to avoid division by zero
        t1 = (1 - p2) * r2 / p2 / (r2 - r1) if p2 * (r2 - r1) > 1e-16 else (1 - p2) * r2 / 1e-16
        t2 = (1 - p1) * r1 / p1 / (r2 - r1) if p1 * (r2 - r1) > 1e-16 else (1 - p1) * r1 / 1e-16
        t3 = (1 - p1) * r1 / p1 if p1 > 1e-16 else (1 - p1) * r1 / 1e-16
        a = 1 + t1 - t2
        b = t3 - t1 * r1 + t2 * r1
        # precision at the target recall on the segment p = r / (a*r + b)
        pr_[idx] = rc_[idx] / (a * rc_[idx] + b)
    return pr_
def get_roc_info(y_true_all, scores_all):
    '''Average ROC curves over repeats/folds onto a common FPR grid.

    Returns a dict with the grid (`xs`), the mean TPR curve and its ±1 std
    band (`ys_mean`, `ys_upper`, `ys_lower`), and the AUC mean/std
    (`auc_mean`, `auc_std`), with the std clipped so mean + std <= 1.
    '''
    fpr_pt = np.linspace(0, 1, 1001)
    tprs, aucs = [], []
    for i, y_true in enumerate(y_true_all):
        scores = scores_all[i]
        fpr, tpr, _ = roc_curve(y_true=y_true, y_score=scores,
                                drop_intermediate=True)
        tpr_grid = interp(fpr_pt, fpr, tpr)
        tpr_grid[0] = 0.0  # every curve starts at the origin
        tprs.append(tpr_grid)
        aucs.append(auc(fpr, tpr))
    tprs_mean = np.mean(tprs, axis=0)
    tprs_std = np.std(tprs, axis=0)
    auc_mean = auc(fpr_pt, tprs_mean)
    auc_std = np.std(aucs)
    if auc_mean + auc_std > 1:
        auc_std = 1 - auc_mean
    return {
        'xs': fpr_pt,
        'ys_mean': tprs_mean,
        'ys_upper': np.minimum(tprs_mean + tprs_std, 1),
        'ys_lower': np.maximum(tprs_mean - tprs_std, 0),
        'auc_mean': auc_mean,
        'auc_std': auc_std
    }
def get_pr_info(y_true_all, scores_all):
    '''Average precision-recall curves over repeats/folds onto a common
    recall grid, using the non-linear interpolation in pr_interp.

    Returns a dict with the grid (`xs`), the mean precision curve and its
    ±1 std band (`ys_mean`, `ys_upper`, `ys_lower`), and the average
    precision mean/std (`auc_mean`, `auc_std`), std clipped so
    mean + std <= 1.
    '''
    rc_pt = np.linspace(0, 1, 1001)
    rc_pt[0] = 1e-16  # pr_interp requires strictly positive recall
    prs, aps = [], []
    for i, y_true in enumerate(y_true_all):
        scores = scores_all[i]
        pr, rc, _ = precision_recall_curve(y_true=y_true, probas_pred=scores)
        aps.append(average_precision_score(y_true=y_true, y_score=scores))
        # sklearn returns the curve with recall descending; flip it ascending
        prs.append(pr_interp(rc_pt, rc[::-1], pr[::-1]))
    prs_mean = np.mean(prs, axis=0)
    prs_std = np.std(prs, axis=0)
    aps_mean = np.mean(aps)
    aps_std = np.std(aps)
    if aps_mean + aps_std > 1:
        aps_std = 1 - aps_mean
    return {
        'xs': rc_pt,
        'ys_mean': prs_mean,
        'ys_upper': np.minimum(prs_mean + prs_std, 1),
        'ys_lower': np.maximum(prs_mean - prs_std, 0),
        'auc_mean': aps_mean,
        'auc_std': aps_std
    }
def get_and_print_metrics(mdl, dat):
    '''Evaluate a fitted model on a dataset and print its metrics.

    `mdl` must expose sklearn-style predict/predict_proba; `dat` must expose
    `.x` (features) and `.y` (labels).

    BUG FIX: the original called get_metrics with three arguments, but
    get_metrics requires a fourth `mask` argument — every call raised
    TypeError. An all-ones mask (evaluate every sample) is now supplied.
    NOTE(review): presumably `dat.y` is a 1-D label array and predict_proba's
    positive-class column is what get_metrics expects as `scores` — confirm
    against callers.
    '''
    y_pred = mdl.predict(dat.x)
    y_prob = mdl.predict_proba(dat.x)
    met_all = get_metrics(dat.y, y_pred, y_prob, np.ones_like(dat.y))
    # reuse the shared printer instead of duplicating its loop here
    print_metrics(met_all)
def get_and_print_metrics_multitask(mdl, dat):
    '''Evaluate a fitted multitask model on a dataset and print its metrics.

    `mdl` must expose sklearn-style predict/predict_proba; `dat` must expose
    `.x` (features) and `.y` (labels, either a dict keyed by task or a 2-D
    array with one column per task).

    BUG FIX: the original called get_metrics_multitask with three arguments,
    but it requires a fourth `mask` argument — every call raised TypeError.
    An all-ones mask (evaluate every sample of every task) is now supplied,
    shaped to match `dat.y` (dict of arrays or a single array).
    '''
    y_pred = mdl.predict(dat.x)
    y_prob = mdl.predict_proba(dat.x)
    if type(dat.y) is dict:
        mask = {k: np.ones_like(v) for k, v in dat.y.items()}
    else:
        mask = np.ones_like(dat.y)
    met = get_metrics_multitask(dat.y, y_pred, y_prob, mask)
    print_metrics_multitask(met)
def split_dataset(dat, ratio=.8, seed=0):
    '''Randomly split a dataset into train/validation subsets.

    Parameters
    ----------
    dat : a torch-style dataset (anything with __len__ and __getitem__)
    ratio : float, fraction of samples assigned to the training split
    seed : int, seed for the split's random generator (reproducible)

    Returns
    -------
    (dat_trn, dat_vld) : two torch.utils.data.Subset objects

    BUG FIX: the original ignored both parameters, hard-coding the ratio to
    .8 and the seed to 0; `ratio` and `seed` are now actually used.
    '''
    len_trn = int(np.round(len(dat) * ratio))
    len_vld = len(dat) - len_trn
    dat_trn, dat_vld = torch.utils.data.random_split(
        dat, (len_trn, len_vld),
        generator=torch.Generator().manual_seed(seed)
    )
    return dat_trn, dat_vld
def l1_regularizer(model, lambda_l1=0.01):
    ''' LASSO: L1 penalty over every parameter whose name ends in "weight"
    (biases and other parameters are excluded). '''
    penalty = 0
    for name, param in model.named_parameters():
        if not name.endswith('weight'):
            continue
        penalty = penalty + lambda_l1 * param.abs().sum()
    return penalty
class ProgressBar(tqdm.tqdm):
    '''A tqdm progress bar with a fixed ASCII layout that also renders a
    dict of values (losses, confusion matrix, ...) in its postfix.'''

    def __init__(self, total, desc, file=sys.stdout):
        super().__init__(total=total, desc=desc, ascii=True,
                         bar_format='{l_bar}{r_bar}', file=file)

    def update(self, batch_size, to_disp):
        '''Advance the bar by `batch_size` and show `to_disp` as postfix.

        The 'cnf' entry is rendered via repr with newlines stripped; every
        other value is assumed to be a tensor and formatted to 6 decimals.
        '''
        postfix = {
            key: (val.__repr__().replace('\n', '') if key == 'cnf'
                  else '{:.6f}'.format(val.cpu().numpy()))
            for key, val in to_disp.items()
        }
        self.set_postfix(postfix)
        super().update(batch_size)
| def _get_gt_mask(logits, target): | |
| target = target.reshape(-1) | |
| mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() | |
| return mask | |
| def _get_other_mask(logits, target): | |
| target = target.reshape(-1) | |
| mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() | |
| return mask | |
def cat_mask(t, mask1, mask2):
    '''Collapse `t` row-wise into two columns: the sum over the entries
    selected by `mask1` and the sum over those selected by `mask2`.'''
    part1 = (t * mask1).sum(dim=1, keepdim=True)
    part2 = (t * mask2).sum(dim=1, keepdim=True)
    return torch.cat((part1, part2), dim=1)
def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    '''Decoupled Knowledge Distillation loss: alpha * TCKD + beta * NCKD.

    TCKD (target-class KD) distills the teacher's binary target/non-target
    split; NCKD (non-target-class KD) distills the distribution over the
    non-target classes with the ground-truth logit pushed to -inf
    (the `- 1000.0 * gt_mask` term).

    Parameters: student/teacher logits of shape (batch, classes), integer
    class `target` of shape (batch,), loss weights `alpha`/`beta`, and the
    softmax `temperature`.

    FIX: F.kl_div's `size_average=False` is deprecated (emits a warning on
    current torch); replaced with the exact equivalent reduction="sum".
    '''
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    # collapse to 2-way (target vs. rest) distributions
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, reduction="sum")
        * (temperature**2)
        / target.shape[0]
    )
    # non-target distributions: mask out the ground-truth logit before softmax
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction="sum")
        * (temperature**2)
        / target.shape[0]
    )
    return alpha * tckd_loss + beta * nckd_loss
def convert_args_kwargs_to_kwargs(func, args, kwargs):
    """Bind `args`/`kwargs` against `func`'s signature and return a single
    mapping of parameter name -> value with defaults filled in."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.arguments