| import logging |
| import math |
| import pickle |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import seaborn as sns |
| import torch |
| import datasets |
| from datasets.utils.logging import disable_progress_bar, enable_progress_bar |
| from sklearn import preprocessing |
| from sklearn.metrics import ( |
| ConfusionMatrixDisplay, |
| accuracy_score, |
| auc, |
| confusion_matrix, |
| f1_score, |
| roc_curve, |
| ) |
| from tqdm.auto import trange |
|
|
| from .emb_extractor import make_colorbar |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict):
    """Right-pad every example of *cell_batch* to a common length.

    Args:
        cell_batch: Dataset-like object exposing ``cell_batch["input_ids"]`` and
            ``cell_batch.map(fn)``.
        max_len: Target length. If ``None``, the longest ``input_ids`` in the
            batch is used.
        label_name: Column holding the labels; it is padded with ``-100``.
            For scalar (per-cell) labels this padding is presumably a no-op —
            TODO confirm.
        gene_token_dict: Token dictionary whose ``"<pad>"`` entry is used as the
            padding id for ``input_ids``.

    Returns:
        The result of ``cell_batch.map(...)``. Each example has padded
        ``input_ids`` and labels, plus an ``attention_mask`` that is 1 wherever
        the padded ``input_ids`` differ from the pad id.
    """
    target_len = max_len
    if target_len is None:
        target_len = max(len(ids) for ids in cell_batch["input_ids"])

    pad_token_id = gene_token_dict.get("<pad>")

    def _pad_example(example):
        # Both columns are padded by the same amount, based on the unpadded ids.
        n_pad = target_len - len(example["input_ids"])
        example[label_name] = np.pad(
            example[label_name],
            (0, n_pad),
            mode="constant",
            constant_values=-100,
        )
        padded_ids = np.pad(
            example["input_ids"],
            (0, n_pad),
            mode="constant",
            constant_values=pad_token_id,
        )
        example["input_ids"] = padded_ids
        example["attention_mask"] = (padded_ids != pad_token_id).astype(int)
        return example

    return cell_batch.map(_pad_example)
|
|
|
|
| |
| |
def find_largest_div(N, K):
    """Return ``N`` minus its remainder modulo ``K``.

    For a positive ``K`` this is the largest multiple of ``K`` not exceeding ``N``.
    """
    return N - N % K
|
|
|
|
def vote(logit_list):
    """Return the index of the largest logit, or ``"tie"`` if the maximum is shared.

    Args:
        logit_list: Non-empty sequence of comparable scores (one per class).

    Returns:
        The position of the unique maximum, or the string ``"tie"`` when two or
        more entries equal the maximum.

    Raises:
        ValueError: If *logit_list* is empty (raised by ``max``).
    """
    top = max(logit_list)
    winners = [idx for idx, value in enumerate(logit_list) if value == top]
    return "tie" if len(winners) > 1 else winners[0]
|
|
|
|
def py_softmax(vector):
    """Return the softmax of *vector* (normalised over all of its elements).

    The maximum is subtracted before exponentiating. Softmax is invariant to a
    constant shift, and this keeps large logits from overflowing ``np.exp`` to
    ``inf`` (which would yield ``nan``).

    Args:
        vector: Array-like of numeric scores.

    Returns:
        A NumPy array of the same shape whose entries sum to 1. For empty input,
        an empty array.
    """
    values = np.asarray(vector)
    if values.size == 0:
        return values
    e = np.exp(values - values.max())
    return e / e.sum()
|
|
|
|
def _infer_model_device(model, default="cuda"):
    """Return the device of *model*'s first parameter, or *default* if unavailable."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return default


def classifier_predict(
    model,
    classifier_type,
    evalset,
    forward_batch_size,
    gene_token_dict,
    predict_metadata=None,
    device=None,
):
    """Run batched inference with *model* over *evalset* and collect predictions.

    Args:
        model: Torch model called as ``model(input_ids=..., attention_mask=...,
            labels=...)``. Its output must expose ``.logits`` whose last
            dimension is the number of classes.
        classifier_type: ``"gene"`` (labels read from column ``"labels"``) or
            ``"cell"`` (labels read from column ``"label"``).
        evalset: Dataset with ``"input_ids"``, ``"length"`` and the label
            column; must support ``len``, ``select`` and column indexing.
        forward_batch_size: Number of examples per forward pass.
        gene_token_dict: Token dictionary; its ``"<pad>"`` entry is the pad id.
        predict_metadata: Optional iterable of column names. Their values are
            collected per example, in evaluation order.
        device: Device the inputs are moved to. Defaults to the device of the
            model's first parameter, or ``"cuda"`` if that can't be determined.

    Returns:
        ``(y_pred, y_true, logits_list, predict_metadata_all)``. Positions whose
        label is ``-100`` are dropped. ``y_pred`` holds the ``vote`` result
        (class index or ``"tie"``) for each kept position. ``logits_list`` holds
        the matching logit rows. ``predict_metadata_all`` is ``None`` unless
        *predict_metadata* is given.

    Raises:
        ValueError: If *classifier_type* is not ``"gene"`` or ``"cell"``.
    """
    if classifier_type == "gene":
        label_name = "labels"
    elif classifier_type == "cell":
        label_name = "label"
    else:
        raise ValueError(
            f"classifier_type must be 'gene' or 'cell', got {classifier_type!r}"
        )

    predict_metadata_all = None
    if predict_metadata is not None:
        predict_metadata_all = {name: [] for name in predict_metadata}

    model.eval()

    evalset_len = len(evalset)
    if evalset_len == 0:
        return [], [], [], predict_metadata_all

    if device is None:
        device = _infer_model_device(model)

    # Every batch is padded to the longest example in the whole evalset so that
    # all per-batch logits share the same sequence dimension.
    max_evalset_len = max(evalset["length"])
    # datasets>=4 returns lazy column objects; slicing materialises the
    # torch-formatted batch so columns come back as tensors.
    materialize_batches = int(datasets.__version__.split(".")[0]) >= 4

    predict_logits = []
    predict_labels = []
    disable_progress_bar()
    try:
        for start in trange(0, evalset_len, forward_batch_size):
            stop = min(start + forward_batch_size, evalset_len)
            batch_evalset = evalset.select(range(start, stop))

            if predict_metadata_all is not None:
                for metadata_name in predict_metadata_all:
                    predict_metadata_all[metadata_name] += batch_evalset[metadata_name]

            padded_batch = preprocess_classifier_batch(
                batch_evalset, max_evalset_len, label_name, gene_token_dict
            )
            padded_batch.set_format(type="torch")
            if materialize_batches:
                padded_batch = padded_batch[:]

            input_data_batch = padded_batch["input_ids"]
            attn_msk_batch = padded_batch["attention_mask"]
            label_batch = padded_batch[label_name]
            with torch.no_grad():
                outputs = model(
                    input_ids=input_data_batch.to(device),
                    attention_mask=attn_msk_batch.to(device),
                    labels=label_batch.to(device),
                )
            logits = outputs.logits.to("cpu")
            # Flatten per batch instead of squeezing: squeeze would drop the
            # batch dimension of a size-1 batch and break torch.cat below.
            predict_logits.append(logits.reshape(-1, logits.shape[-1]))
            predict_labels.append(label_batch.to("cpu").reshape(-1))
    finally:
        enable_progress_bar()

    all_logits = torch.cat(predict_logits)
    all_labels = torch.cat(predict_labels)

    y_pred, y_true, logits_list = [], [], []
    for logits_row, label in zip(all_logits.tolist(), all_labels.tolist()):
        # -100 marks padded positions (and presumably any unlabeled token);
        # they carry no ground truth, so they are excluded.
        if label == -100:
            continue
        y_pred.append(vote(logits_row))
        y_true.append(label)
        logits_list.append(logits_row)

    return y_pred, y_true, logits_list, predict_metadata_all
|
|
|
|
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
    """Compute classification metrics, plus ROC data for binary problems.

    Args:
        y_pred: Predicted labels.
        y_true: Ground-truth labels.
        logits_list: Per-example logit rows aligned with *y_true*. Used only
            when ``num_classes == 2``.
        num_classes: Number of classes. ROC metrics are computed only for 2.
        labels: Label order for the rows/columns of the confusion matrix.

    Returns:
        ``(conf_mat, macro_f1, acc, roc_metrics)``. ``roc_metrics`` is ``None``
        unless ``num_classes == 2``. In that case it is a dict with ``"fpr"``,
        ``"tpr"``, ``"interp_tpr"`` (tpr interpolated on 100 evenly spaced FPR
        points, first value forced to 0), ``"auc"`` and ``"tpr_wt"`` (number of
        ROC points).
    """
    conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    acc = accuracy_score(y_true, y_pred)

    if num_classes != 2:
        return conf_mat, macro_f1, acc, None

    # Score each example by the softmax probability of class 1.
    positive_scores = [py_softmax(row)[1] for row in logits_list]
    fpr, tpr, _thresholds = roc_curve(y_true, positive_scores)
    interp_tpr = np.interp(np.linspace(0, 1, 100), fpr, tpr)
    interp_tpr[0] = 0.0
    roc_metrics = {
        "fpr": fpr,
        "tpr": tpr,
        "interp_tpr": interp_tpr,
        "auc": auc(fpr, tpr),
        "tpr_wt": len(tpr),
    }
    return conf_mat, macro_f1, acc, roc_metrics
|
|
|
|
| |
def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    """Combine per-fold ROC results using weights proportional to *all_tpr_wt*.

    Args:
        all_tpr: Per-fold TPR arrays of equal length (e.g. interpolated TPRs).
        all_roc_auc: Per-fold AUC values.
        all_tpr_wt: Per-fold weights. They are normalised to sum to 1.

    Returns:
        ``(mean_tpr, roc_auc, roc_auc_sd)``: the weighted mean TPR curve (last
        point forced to 1.0), the weighted mean AUC and the weighted standard
        deviation of the AUCs around that mean.
    """
    total_wt = sum(all_tpr_wt)
    fold_weights = [wt / total_wt for wt in all_tpr_wt]

    mean_tpr = np.sum(
        [tpr * weight for tpr, weight in zip(all_tpr, fold_weights)], axis=0
    )
    mean_tpr[-1] = 1.0

    roc_auc = np.sum(
        [fold_auc * weight for fold_auc, weight in zip(all_roc_auc, fold_weights)]
    )
    squared_dev = (all_roc_auc - roc_auc) ** 2
    roc_auc_sd = math.sqrt(np.average(squared_dev, weights=fold_weights))
    return mean_tpr, roc_auc, roc_auc_sd
|
|
|
|
| |
def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
    """Plot one ROC curve per model, save the figure as a PDF and show it.

    Args:
        roc_metric_dict: Maps model name to a dict with ``"mean_fpr"``,
            ``"mean_tpr"`` and ``"all_roc_auc"``. If it also has ``"roc_auc"``,
            it must have ``"roc_auc_sd"`` too, and ``"all_roc_auc"`` must be
            sized (the SD is shown only when it has more than one entry).
            Without ``"roc_auc"``, ``"all_roc_auc"`` is formatted directly as
            a number.
        model_style_dict: Maps model name to ``{"color": ..., "linestyle": ...}``.
        title: Plot title.
        output_dir: Directory the PDF is written to.
        output_prefix: File stem; the plot is saved as ``<output_prefix>_roc.pdf``.
    """
    fig = plt.figure()
    fig.set_size_inches(10, 8)
    sns.set(font_scale=2)
    sns.set_style("white")
    lw = 3
    for model_name, metrics in roc_metric_dict.items():
        mean_fpr = metrics["mean_fpr"]
        mean_tpr = metrics["mean_tpr"]
        color = model_style_dict[model_name]["color"]
        linestyle = model_style_dict[model_name]["linestyle"]
        if "roc_auc" not in metrics:
            all_roc_auc = metrics["all_roc_auc"]
            label = f"{model_name} (AUC {all_roc_auc:0.2f})"
        else:
            roc_auc = metrics["roc_auc"]
            roc_auc_sd = metrics["roc_auc_sd"]
            if len(metrics["all_roc_auc"]) > 1:
                # The backslash is escaped: "\p" is an invalid escape sequence
                # (SyntaxWarning on Python 3.12+). The emitted text is unchanged.
                label = f"{model_name} (AUC {roc_auc:0.2f} $\\pm$ {roc_auc_sd:0.2f})"
            else:
                label = f"{model_name} (AUC {roc_auc:0.2f})"
        plt.plot(
            mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
        )

    # Chance-level diagonal.
    plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(loc="lower right")

    # Append the extension rather than calling with_suffix(). with_suffix
    # would replace anything after a "." in output_prefix ("run.v2" -> "run.pdf").
    output_file = Path(output_dir) / f"{output_prefix}_roc.pdf"
    plt.savefig(output_file, bbox_inches="tight")
    plt.show()
|
|
|
|
| |
def plot_confusion_matrix(
    conf_mat_df, title, output_dir, output_prefix, custom_class_order
):
    """Plot a row-normalised confusion matrix, save it as a PDF and show it.

    Args:
        conf_mat_df: Square DataFrame of counts. The index labels are used for
            the tick labels together with each row's total. Presumably rows are
            true classes and columns are predictions — confirm against callers.
        title: Plot title.
        output_dir: Directory the PDF is written to.
        output_prefix: File stem; the plot is saved as
            ``<output_prefix>_conf_mat.pdf``.
        custom_class_order: Optional class order applied to both rows and
            columns. Classes absent from *conf_mat_df* become all-zero rows and
            columns instead of NaN.
    """
    sns.set(font_scale=1)
    sns.set_style("whitegrid", {"axes.grid": False})
    if custom_class_order is not None:
        # fill_value=0: NaN cells for missing classes would make normalize() raise.
        conf_mat_df = conf_mat_df.reindex(
            index=custom_class_order, columns=custom_class_order, fill_value=0
        )
    display_labels = generate_display_labels(conf_mat_df)
    # L1-normalise each row so every row sums to 1 (all-zero rows stay zero).
    conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
    display = ConfusionMatrixDisplay(
        confusion_matrix=conf_mat, display_labels=display_labels
    )
    # Plot onto an explicit axes. Without ax=, display.plot() creates its own
    # figure, which would ignore the 10x10 size and leave a blank figure open.
    fig, ax = plt.subplots(figsize=(10, 10))
    display.plot(cmap="Blues", values_format=".2g", ax=ax)
    ax.set_title(title)

    # Append the extension rather than calling with_suffix(), which would
    # truncate an output_prefix containing a ".".
    output_file = Path(output_dir) / f"{output_prefix}_conf_mat.pdf"
    fig.savefig(output_file, bbox_inches="tight")
    plt.show()
|
|
|
|
def generate_display_labels(conf_mat_df):
    """Build one tick label per row of *conf_mat_df*.

    Each label is the row's index value, followed on a second line by "n="
    and the row total formatted with no decimal places.
    """
    return [
        f"{label}\nn={conf_mat_df.iloc[position, :].sum():.0f}"
        for position, label in enumerate(conf_mat_df.index)
    ]
|
|
|
|
def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
    """Draw a clustermap of per-class predictions and save it as a PDF.

    Args:
        predictions_df: DataFrame with a ``"true"`` column (passed to
            ``make_colorbar`` to colour the rows) plus one column per predicted
            class. Every predicted-class column name must be a key of the colour
            dict returned by ``make_colorbar``. Presumably the values are
            per-class prediction scores — confirm against callers.
        title: Title of the colour legend.
        output_dir: Directory the PDF is written to.
        output_prefix: File stem; the plot is saved as ``<output_prefix>_pred.pdf``.
        kwargs_dict: Optional overrides merged into the ``sns.clustermap`` kwargs.
    """
    sns.set(font_scale=2)
    # No plt.figure() here: clustermap always creates its own figure, so a
    # pre-made one would only be left open (and rendered blank in notebooks).
    label_colors, label_color_dict = make_colorbar(predictions_df, "true")
    predictions_df = predictions_df.drop(columns=["true"])
    predict_label_list = list(predictions_df.columns)
    predict_colors_list = [label_color_dict[label] for label in predict_label_list]
    predict_colors = pd.DataFrame(
        pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
    )

    default_kwargs_dict = {
        "row_cluster": False,
        "col_cluster": False,
        "row_colors": label_colors,
        "col_colors": predict_colors,
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "center": 0,
        "cmap": "vlag",
    }
    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)
    g = sns.clustermap(predictions_df, **default_kwargs_dict)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # Zero-size bars act as legend handles, one per label colour.
    for label_name, label_color in label_color_dict.items():
        g.ax_col_dendrogram.bar(
            0, 0, color=label_color, label=label_name, linewidth=0
        )

    g.ax_col_dendrogram.legend(
        title=f"{title}",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, 1),
        facecolor="white",
    )

    # Append the extension rather than calling with_suffix(), which would
    # truncate an output_prefix containing a ".".
    output_file = Path(output_dir) / f"{output_prefix}_pred.pdf"
    g.savefig(output_file, bbox_inches="tight")
|
|