|
|
import os |
|
|
from tqdm.auto import tqdm, trange |
|
|
# GPU indices this process may use; CUDA_VISIBLE_DEVICES must be set before
# torch initializes CUDA for the restriction to take effect.
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
# Verbose NCCL logging, useful when debugging multi-GPU communication.
os.environ["NCCL_DEBUG"] = "INFO"
|
|
|
|
|
|
|
|
from collections import Counter |
|
|
import seaborn as sns; sns.set() |
|
|
from datasets import load_from_disk |
|
|
from sklearn.metrics import accuracy_score, f1_score |
|
|
from transformers import Trainer |
|
|
from transformers.training_args import TrainingArguments |
|
|
import pandas as pd |
|
|
from datasets.utils.logging import disable_progress_bar, enable_progress_bar |
|
|
from sklearn import preprocessing |
|
|
from sklearn.metrics import ( |
|
|
ConfusionMatrixDisplay, |
|
|
accuracy_score, |
|
|
auc, |
|
|
confusion_matrix, |
|
|
f1_score, |
|
|
roc_curve, |
|
|
) |
|
|
from pathlib import Path |
|
|
|
|
|
import sys |
|
|
|
|
|
from geneformer import DataCollatorForCellClassification |
|
|
from datasets import load_from_disk |
|
|
import sys |
|
|
from tqdm.notebook import tqdm |
|
|
import seaborn as sns |
|
|
import matplotlib.pyplot as plt |
|
|
from geneformer.pretrainer import token_dictionary |
|
|
import datetime |
|
|
import time |
|
|
import pickle |
|
|
import random |
|
|
import subprocess |
|
|
import numpy as np |
|
|
import pytz |
|
|
import torch |
|
|
from datasets import load_from_disk, Dataset |
|
|
from transformers import (BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback, |
|
|
Trainer, BertModel, BertPreTrainedModel, BertForSequenceClassification, BertForTokenClassification) |
|
|
from geneformer import GeneformerPretrainer |
|
|
from torch import Tensor |
|
|
from transformers.modeling_outputs import MaskedLMOutput |
|
|
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform |
|
|
from transformers.activations import ACT2FN |
|
|
from typing import List, Optional, Tuple, Union |
|
|
import torch.nn.functional as F |
|
|
|
|
|
# Accumulators apparently intended to collect per-run evaluation metrics;
# NOTE(review): nothing in this chunk appends to them — verify they are used
# elsewhere or remove.
macro_f1_list = []
acc_list = []

# NOTE(review): number of repeated iterations? Unused in this chunk — confirm.
iter_step = 2
|
|
|
|
|
class CustomBertForMaskedLM(BertPreTrainedModel):
    """BERT masked-LM head with rank-aware probability post-processing.

    Differences from the stock ``BertForMaskedLM``:
      * a bare linear decoder (weight-tied to the input embeddings) plus a
        separate learned bias replaces ``BertOnlyMLMHead``;
      * ``forward`` converts logits to probabilities and post-processes them:
        unmasked ("known", label == -100) positions are pinned to a one-hot
        on the observed input token, and masked positions discount vocabulary
        mass already used at earlier sequence positions;
      * the MLM loss is a hand-rolled negative log-likelihood over these
        post-processed probabilities, not ``CrossEntropyLoss`` on raw logits.

    NOTE(review): the ``logits`` field of the returned ``MaskedLMOutput``
    holds probabilities (post-softmax), not raw logits.
    """

    # Decoder weights are tied to the word embeddings, so missing decoder
    # weights on checkpoint load are expected.
    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    _tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]

    def __init__(self, config):
        super().__init__(config)
        # No pooling layer: only token-level outputs are needed for MLM.
        self.bert = BertModel(config, add_pooling_layer=False)
        # Dense + activation + LayerNorm applied before the decoder, mirroring
        # the standard BERT prediction head transform.
        self.transform = BertPredictionHeadTransform(config)

        # Decoder projects hidden states to vocabulary scores; bias is a
        # separate parameter so the weight alone can be tied to embeddings.
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        self.init_weights()

        self.tie_weights()

    def tie_weights(self):
        """
        Ties the weights between the input embeddings and output decoder weights.
        """
        self.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Discount each position's distribution by mass used at earlier positions.

        Shifts the probability tensor right by one position (seeding position 0
        with one-hots of the known genes), takes the cumulative product of
        (1 - p) along the sequence — i.e. the probability a token has not been
        "used" by any earlier position — and re-weights ``probs`` by it, then
        renormalizes.
        """
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()
        # label == -100 marks "known" (unmasked) positions.
        non_mask = labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)
        # NOTE(review): unlike assign_known_gene_probs below, labels are NOT
        # truncated to input length here — assumes labels and input_ids share
        # the same sequence length; TODO confirm.
        known_gene_indices = input_ids[non_mask]

        # Position-0 seed: one-hot mass on every known gene of each batch row.
        zeros = torch.zeros((batch_size, 1, vocab_size), device=device)
        zeros[non_mask_indices[0], 0, known_gene_indices] = 1.0
        # Shift right so each position sees its predecessor's distribution.
        probs_shifted = torch.cat((zeros, probs[:, :-1, :]), dim=1)
        inv_probs_shifted = 1 - probs_shifted

        # cumprod over positions = P(token not already accounted for earlier).
        cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
        modified_probs = probs * cumprod_inv_probs

        # Renormalize to a proper distribution; clamp guards divide-by-zero.
        normalized_probs = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
        modified_probs = modified_probs / normalized_probs

        return modified_probs

    def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Pin unmasked positions to a one-hot on the observed input token.

        Masked positions keep their predicted distribution; known positions
        (label == -100) get probability 1 on the input token and 0 elsewhere.
        """
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()

        # Labels may be padded longer than input_ids; align them first.
        truncated_labels = labels[:, :input_seq_length]

        non_mask = truncated_labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)

        ones = torch.ones((batch_size, seq_length, vocab_size), device=device)
        zeros = torch.zeros((batch_size, seq_length, vocab_size), device=device)

        known_gene_indices = input_ids[non_mask]

        # Zero out the whole distribution at known positions ...
        ones[non_mask_indices[0], non_mask_indices[1], :] = 0.0
        # ... and place all mass on the observed token there.
        zeros[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0

        modified_probs = probs * ones
        modified_probs = modified_probs + zeros

        # Defensive renormalization; clamp guards divide-by-zero.
        modified_probs = modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)

        return modified_probs

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        token_type_ids: Tensor | None = None,
        position_ids: Tensor | None = None,
        head_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        encoder_hidden_states: Tensor | None = None,
        encoder_attention_mask: Tensor | None = None,
        labels: Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Encode, decode to probabilities, and post-process known positions.

        NOTE(review): the post-processing below uses ``labels`` before the
        ``labels is not None`` check — a call with labels=None would fail
        inside assign_known_gene_probs; confirm callers always pass labels.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_transform = self.transform(hidden_states)
        logits = self.decoder(hidden_transform) + self.bias

        probs = F.softmax(logits, dim=-1)

        # Pin known positions, discount earlier-used mass, then pin again
        # (the convert step can disturb the pinned known positions).
        probs = self.assign_known_gene_probs(probs, input_ids, labels)
        convert_probs = self.probability_convert(probs, input_ids, labels)
        assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)

        masked_lm_loss = None
        if labels is not None:
            probs_flat = assigned_probs.view(-1, self.config.vocab_size)
            labels_flat = labels.view(-1)
            # Only masked positions (label != -100) contribute to the loss.
            mask = (labels != -100).float().view(-1)

            # Manual NLL over the post-processed probabilities; clamp guards log(0).
            masked_lm_loss = -torch.log(torch.clamp(probs_flat[torch.arange(len(labels_flat)), labels_flat], min=1e-18)) * mask
            masked_lm_loss = masked_lm_loss.sum() / mask.sum()

        else:
            # NOTE(review): dead assignment — `loss` is never read and
            # masked_lm_loss is already None on this path.
            loss = None

        if not return_dict:
            output = (assigned_probs,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=assigned_probs,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append a dummy PAD slot for generation (standard BertForMaskedLM shim)."""
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # Extend the attention mask with a zero column, then append a PAD
        # token to give the model a slot to predict into.
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
|
|
|
|
|
def prepare_data(
    input_data_file,
    output_directory,
    output_prefix,
    split_id_dict=None,
    test_size=None,
    attr_to_split=None,
    attr_to_balance=None,
    max_trials=100,
    pval_threshold=0.1,
):
    """
    Prepare data for cell state or gene classification.

    **Parameters**

    input_data_file : Path
        | Path to directory containing .dataset input
    output_directory : Path
        | Path to directory where prepared data will be saved
    output_prefix : str
        | Prefix for output file
    split_id_dict : None, dict
        | Dictionary of IDs for train and test splits
        | Three-item dictionary with keys: attr_key, train, test
        | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
        | train: list of IDs in the attr_key column to include in the train split
        | test: list of IDs in the attr_key column to include in the test split
        | For example: {"attr_key": "individual",
        |               "train": ["patient1", "patient2", "patient3", "patient4"],
        |               "test": ["patient5", "patient6"]}
    test_size : None, float
        | Proportion of data to be saved separately and held out for test set
        | (e.g. 0.2 if intending hold out 20%)
        | If None, will inherit from split_sizes["test"] from Classifier
        | The training set will be further split to train / validation in self.validate
        | Note: only available for CellClassifiers
    attr_to_split : None, str
        | Key for attribute on which to split data while balancing potential confounders
        | e.g. "patient_id" for splitting by patient while balancing other characteristics
        | Note: only available for CellClassifiers
    attr_to_balance : None, list
        | List of attribute keys on which to balance data while splitting on attr_to_split
        | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
        | Note: only available for CellClassifiers
    max_trials : None, int
        | Maximum number of trials of random splitting to try to achieve balanced other attributes
        | If no split is found without significant (p<0.05) differences in other attributes, will select best
        | Note: only available for CellClassifiers
    pval_threshold : None, float
        | P-value threshold to use for attribute balancing across splits
        | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance

    NOTE(review): this function reads several names that are not defined in
    this file (oos_test_size, filter_data, nproc, classifier, logger,
    rare_threshold, max_ncells, cell_state_dict, cu) — it appears to have been
    extracted from Geneformer's Classifier class; confirm these are provided
    as module globals before use. The bare ``raise`` statements below raise
    RuntimeError("No active exception to re-raise") rather than a specific
    exception type.
    """
    # Fall back to the module-level out-of-sample test proportion.
    if test_size is None:
        test_size = oos_test_size

    data = load_and_filter(filter_data, nproc, input_data_file)

    # "label"/"labels" columns are reserved for the numeric class IDs added below.
    if classifier == "cell":
        if "label" in data.features:
            logger.error(
                "Column name 'label' must be reserved for class IDs. Please rename column."
            )
            raise
    elif classifier == "gene":
        if "labels" in data.features:
            logger.error(
                "Column name 'labels' must be reserved for class IDs. Please rename column."
            )
            raise

    if (attr_to_split is not None) and (attr_to_balance is None):
        logger.error(
            "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
        )
        raise

    if not isinstance(attr_to_balance, list):
        attr_to_balance = [attr_to_balance]

    if classifier == "cell":
        # Remove cell states below the rarity threshold.
        data = remove_rare(
            data, rare_threshold, cell_state_dict["state_key"], nproc
        )

        # Shuffle and cap dataset size.
        data = downsample_and_shuffle(
            data, max_ncells, None, cell_state_dict
        )

        # Rename the state column to the reserved "label" column.
        data = rename_cols(data, cell_state_dict["state_key"])

    # Convert class labels to numeric IDs; id_class_dict maps ID -> class name.
    data, id_class_dict = label_classes(
        classifier, data, None, nproc
    )

    # Persist the ID -> class mapping for later evaluation.
    id_class_output_path = (
        Path(output_directory) / f"{output_prefix}_id_class_dict"
    ).with_suffix(".pkl")
    with open(id_class_output_path, "wb") as f:
        pickle.dump(id_class_dict, f)

    if split_id_dict is not None:
        # Explicit ID-based train/test split.
        data_dict = dict()
        data_dict["train"] = filter_by_dict(
            data, {split_id_dict["attr_key"]: split_id_dict["train"]}, nproc
        )
        data_dict["test"] = filter_by_dict(
            data, {split_id_dict["attr_key"]: split_id_dict["test"]}, nproc
        )
        train_data_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled_train"
        ).with_suffix(".dataset")
        test_data_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled_test"
        ).with_suffix(".dataset")
        data_dict["train"].save_to_disk(str(train_data_output_path))
        data_dict["test"].save_to_disk(str(test_data_output_path))
    elif (test_size is not None) and (classifier == "cell"):
        if 1 > test_size > 0:
            if attr_to_split is None:
                # Simple random split with a fixed seed for reproducibility.
                data_dict = data.train_test_split(
                    test_size=test_size,
                    stratify_by_column=None,
                    seed=42,
                )
                train_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_train"
                ).with_suffix(".dataset")
                test_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_test"
                ).with_suffix(".dataset")
                data_dict["train"].save_to_disk(str(train_data_output_path))
                data_dict["test"].save_to_disk(str(test_data_output_path))
            else:
                # Split on attr_to_split while balancing attr_to_balance.
                data_dict, balance_df = cu.balance_attr_splits(
                    data,
                    attr_to_split,
                    attr_to_balance,
                    test_size,
                    max_trials,
                    pval_threshold,
                    cell_state_dict["state_key"],
                    nproc,
                )
                balance_df.to_csv(
                    f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
                )
                train_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_train"
                ).with_suffix(".dataset")
                test_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_test"
                ).with_suffix(".dataset")
                data_dict["train"].save_to_disk(str(train_data_output_path))
                data_dict["test"].save_to_disk(str(test_data_output_path))
        else:
            # Invalid test_size proportion: save the whole labeled dataset.
            data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled"
            ).with_suffix(".dataset")
            data.save_to_disk(str(data_output_path))
            print(data_output_path)
    else:
        # No split requested: save the whole labeled dataset.
        data_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled"
        ).with_suffix(".dataset")
        data.save_to_disk(str(data_output_path))
|
|
|
|
|
def load_and_filter(filter_data, nproc, input_data_file):
    """Load a saved HF dataset from disk and optionally filter it.

    filter_data is a dict of {column: allowed_values} or None (no filtering);
    nproc is the number of processes for the filter step.
    """
    dataset = load_from_disk(input_data_file)
    if filter_data is None:
        return dataset
    return filter_by_dict(dataset, filter_data, nproc)
|
|
|
|
|
def get_num_classes(id_class_dict):
    """Return the number of distinct class labels in an ID -> class mapping."""
    distinct_classes = set(id_class_dict.values())
    return len(distinct_classes)
|
|
|
|
|
def filter_by_dict(data, filter_data, nproc):
    """Filter a dataset to rows whose column values are in the allowed lists.

    **Parameters**

    data : Dataset
        | Dataset to filter.
    filter_data : dict
        | {column_name: allowed_values}; rows are kept only if every listed
        | column's value is contained in its allowed_values collection.
    nproc : int
        | Number of processes for the filter step.

    Raises ValueError if no rows remain after filtering.
    """
    for key, value in filter_data.items():

        def filter_data_by_criteria(example):
            return example[key] in value

        data = data.filter(filter_data_by_criteria, num_proc=nproc)
    if len(data) == 0:
        logger.error("No cells remain after filtering. Check filtering criteria.")
        # Bug fix: was a bare `raise` outside any except block, which itself
        # raises RuntimeError("No active exception to re-raise"). Raise an
        # explicit, catchable exception instead.
        raise ValueError("No cells remain after filtering. Check filtering criteria.")
    return data
|
|
def remove_rare(data, rare_threshold, label, nproc):
    """Drop rows whose label value occurs in <= rare_threshold fraction of rows.

    **Parameters**

    data : Dataset
        | Dataset with a `label` column.
    rare_threshold : float
        | Minimum fraction of total cells a label must exceed to be kept;
        | 0 disables filtering and returns `data` unchanged.
    label : str
        | Name of the column holding the class labels.
    nproc : int
        | Number of processes for the filter step.
    """
    if rare_threshold > 0:
        total_cells = len(data)
        label_counter = Counter(data[label])
        # Bug fix: iterating a Counter yields keys only, so the original
        # `for k, v in label_counter` raised on unpacking; use .items().
        nonrare_label_dict = {
            label: [
                k
                for k, v in label_counter.items()
                if (v / total_cells) > rare_threshold
            ]
        }
        data = filter_by_dict(data, nonrare_label_dict, nproc)
    return data
|
|
def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
    """Shuffle the dataset (seed 42) and cap its size.

    If max_ncells is set, keep at most that many rows (post-shuffle). If
    max_ncells_per_class is set, additionally subsample each class in the
    cell_state_dict["state_key"] column to that cap.
    """
    shuffled = data.shuffle(seed=42)
    total = len(shuffled)

    # Global cap on dataset size.
    if max_ncells is not None and total > max_ncells:
        shuffled = shuffled.select(list(range(max_ncells)))

    # Optional per-class cap.
    if max_ncells_per_class is not None:
        class_labels = shuffled[cell_state_dict["state_key"]]
        random.seed(42)
        keep_indices = subsample_by_class(class_labels, max_ncells_per_class)
        shuffled = shuffled.select(keep_indices)

    return shuffled
|
|
def rename_cols(data, state_key):
    """Rename the cell-state column to the reserved "label" column."""
    return data.rename_column(state_key, "label")
|
|
def label_classes(classifier, data, gene_class_dict, nproc):
    """Convert class labels to contiguous integer IDs.

    For cell classifiers, maps each distinct value of the "label" column to an
    integer ID in place. For gene classifiers, first keeps only cells that
    contain at least one target gene, then labels tokens via
    label_gene_classes.

    Returns (data, id_class_dict) where id_class_dict maps ID -> class label.

    NOTE(review): relies on names not defined in this file (pu, logger,
    label_gene_classes) — presumably Geneformer utilities; verify. The bare
    ``raise`` below raises RuntimeError("No active exception to re-raise").
    """
    if classifier == "cell":
        label_set = set(data["label"])
    elif classifier == "gene":

        def if_contains_label(example):
            # Keep cells containing at least one target gene token.
            a = pu.flatten_list(gene_class_dict.values())
            b = example["input_ids"]
            return not set(a).isdisjoint(b)

        data = data.filter(if_contains_label, num_proc=nproc)
        label_set = gene_class_dict.keys()

    if len(data) == 0:
        logger.error(
            "No cells remain after filtering for target genes. Check target gene list."
        )
        raise

    # class label -> integer ID, and the inverse mapping for reporting.
    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
    id_class_dict = {v: k for k, v in class_id_dict.items()}

    def classes_to_ids(example):
        if classifier == "cell":
            example["label"] = class_id_dict[example["label"]]
        elif classifier == "gene":
            example["labels"] = label_gene_classes(
                example, class_id_dict, gene_class_dict
            )
        return example

    data = data.map(classes_to_ids, num_proc=nproc)
    return data, id_class_dict
|
|
|
|
|
def train_classifier(
    model_directory,
    num_classes,
    train_data,
    eval_data,
    output_directory,
    predict=False,
    classifier='cell',
    no_eval=False,
    quantize=False,
    freeze_layers=2,
    training_args=None,
):
    """
    Fine-tune model for cell state or gene classification.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    num_classes : int
        | Number of classes for classifier
    train_data : Dataset
        | Loaded training .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    eval_data : None, Dataset
        | (Optional) Loaded evaluation .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    output_directory : Path
        | Path to directory where fine-tuned model will be saved
    predict : bool
        | Whether or not to save eval predictions from trainer
    classifier : str
        | "cell" or "gene"
    no_eval : bool
        | If True, train without evaluation even if eval_data is given
    quantize : bool
        | Whether to load the model quantized
    freeze_layers : None, int
        | Number of initial encoder layers to freeze (None keeps the default)
    training_args : None, dict
        | Optional overrides merged into the default TrainingArguments kwargs
        | (new parameter; fixes the previously undefined `training_args` name)
    """
    # Drop all columns except labels / input_ids / length.
    train_data, eval_data = validate_and_clean_cols(
        train_data, eval_data, classifier
    )

    if (no_eval is True) and (eval_data is not None):
        logger.warning(
            "no_eval set to True; model will be trained without evaluation."
        )
        eval_data = None

    if (classifier == "gene") and (predict is True):
        logger.warning(
            "Predictions during training not currently available for gene classifiers; setting predict to False."
        )
        predict = False

    # Refuse to clobber a previously saved model.
    saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
    if os.path.isfile(saved_model_test) is True:
        logger.error("Model already saved to this designated output directory.")
        # Bug fix: was a bare `raise` outside any except block.
        raise FileExistsError(saved_model_test)

    # Bug fix: was `output_dir`, an undefined name in this function.
    os.makedirs(output_directory, exist_ok=True)

    model = load_model(
        "CellClassifier",
        num_classes,
        model_directory,
        "train",
        quantize=quantize,
    )

    # Initialize the classifier's word embeddings from the custom pretrained
    # masked-LM checkpoint.
    pretrained_model = CustomBertForMaskedLM.from_pretrained(model_directory)
    pretrained_word_embeddings = pretrained_model.bert.embeddings.word_embeddings.weight.clone()
    model.bert.embeddings.word_embeddings.load_state_dict({"weight": pretrained_word_embeddings})

    def_training_args, def_freeze_layers = get_default_train_args(
        model, classifier, train_data, output_directory
    )

    # Merge caller-supplied TrainingArguments overrides.
    if training_args is not None:
        def_training_args.update(training_args)
    # Log roughly 10 times per epoch.
    logging_steps = round(
        len(train_data) / def_training_args["per_device_train_batch_size"] / 10
    )
    def_training_args["logging_steps"] = logging_steps
    def_training_args["output_dir"] = output_directory
    if eval_data is None:
        def_training_args["evaluation_strategy"] = "no"
        def_training_args["load_best_model_at_end"] = False
    training_args_init = TrainingArguments(**def_training_args)

    if freeze_layers is not None:
        def_freeze_layers = freeze_layers

    # Freeze the first N encoder layers.
    if def_freeze_layers > 0:
        modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    if classifier == "cell":
        data_collator = DataCollatorForCellClassification()
    elif classifier == "gene":
        # Bug fix: was `self.classifier` inside a module-level function.
        # NOTE(review): DataCollatorForGeneClassification is not imported in
        # this file — verify it is in scope before using the gene branch.
        data_collator = DataCollatorForGeneClassification()

    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=eval_data,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model(output_directory)
    if predict is True:
        # Persist raw predictions and metrics on the eval set.
        predictions = trainer.predict(eval_data)
        prediction_output_path = f"{output_directory}/predictions.pkl"
        with open(prediction_output_path, "wb") as f:
            pickle.dump(predictions, f)
        trainer.save_metrics("eval", predictions.metrics)
    return trainer
|
|
|
|
|
def validate_and_clean_cols(train_data, eval_data, classifier):
    """Check for the required label column and strip all other extras.

    Keeps only the label column ("label" for cell, "labels" for gene) plus
    "input_ids" and "length". Returns the cleaned (train_data, eval_data).
    """
    if classifier == "cell":
        label_col = "label"
    elif classifier == "gene":
        label_col = "labels"

    keep = [label_col, "input_ids", "length"]

    if label_col not in train_data.column_names:
        logger.error(f"train_data must contain column {label_col} with class labels.")
        raise
    train_data = remove_cols(train_data, keep)

    if eval_data is not None:
        if label_col in eval_data.column_names:
            eval_data = remove_cols(eval_data, keep)
        else:
            logger.error(
                f"eval_data must contain column {label_col} with class labels."
            )
            raise
    return train_data, eval_data
|
|
|
|
|
def remove_cols(data, cols_to_keep):
    """Drop every dataset column not listed in cols_to_keep."""
    to_drop = [col for col in data.features.keys() if col not in cols_to_keep]
    return data.remove_columns(to_drop)
|
|
|
|
|
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
    """Load a pretrained or classifier BERT model, optionally quantized.

    **Parameters**

    model_type : str
        | One of "Pretrained", "GeneClassifier", "CellClassifier",
        | "MTLCellClassifier", or "MTLCellClassifier-Quantized" (alias for
        | MTLCellClassifier with 8-bit quantization).
    num_classes : int
        | Number of classes (ignored for "Pretrained").
    model_directory : Path
        | Directory or hub name passed to from_pretrained.
    mode : str
        | "train" or "eval"; "eval" enables hidden-state output and eval mode.
    quantize : bool
        | Load quantized (8-bit for MTLCellClassifier, otherwise 4-bit + LoRA).

    NOTE(review): BitsAndBytesConfig, LoraConfig and get_peft_model are not
    imported in this file — presumably provided elsewhere; verify before
    using quantize=True.
    """
    # Alias: quantized MTL classifier is the MTL classifier with 8-bit load.
    if model_type == "MTLCellClassifier-Quantized":
        model_type = "MTLCellClassifier"
        quantize = True

    # Hidden states are only needed for evaluation/embedding extraction.
    output_hidden_states = (mode == "eval")

    if quantize:
        if model_type == "MTLCellClassifier":
            # 8-bit quantization, no LoRA adapters.
            quantize_config = BitsAndBytesConfig(load_in_8bit=True)
            peft_config = None
        else:
            # 4-bit NF4 quantization with LoRA adapters for fine-tuning.
            quantize_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            peft_config = LoraConfig(
                lora_alpha=128,
                lora_dropout=0.1,
                r=64,
                bias="none",
                task_type="TokenClassification",
            )
    else:
        quantize_config = None
        peft_config = None

    # Map the logical model type to its transformers class.
    model_classes = {
        "Pretrained": BertForMaskedLM,
        "GeneClassifier": BertForTokenClassification,
        "CellClassifier": BertForSequenceClassification,
        "MTLCellClassifier": BertForMaskedLM
    }

    model_class = model_classes.get(model_type)
    if not model_class:
        raise ValueError(f"Unknown model type: {model_type}")

    model_args = {
        "pretrained_model_name_or_path": model_directory,
        "output_hidden_states": output_hidden_states,
        "output_attentions": False,
    }

    # Classifier heads need the label count; the plain MLM does not.
    if model_type != "Pretrained":
        model_args["num_labels"] = num_classes

    if quantize_config:
        model_args["quantization_config"] = quantize_config

    model = model_class.from_pretrained(**model_args)

    if mode == "eval":
        model.eval()

    if not quantize:
        # Quantized models are device-placed by bitsandbytes; move only
        # full-precision models manually.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
    elif peft_config:
        # Required for gradient flow through quantized base weights.
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)

    return model
|
|
|
|
|
def get_default_train_args(model, classifier, data, output_dir):
    """Assemble the default TrainingArguments kwargs for fine-tuning.

    Returns (training_args_dict, default_frozen_layer_count). Cell
    classifiers train for 10 epochs with per-epoch evaluation; gene
    classifiers for 1 epoch without evaluation. Extra optimizer
    hyperparameters are only specified for 6-layer models.
    """
    num_layers = quant_layers(model)
    freeze_layers_get = 0
    batch_size = 12

    if classifier == "cell":
        epochs = 10
        evaluation_strategy = "epoch"
        load_best_model_at_end = True
    else:
        epochs = 1
        evaluation_strategy = "no"
        load_best_model_at_end = False

    default_training_args = {
        "per_device_train_batch_size": batch_size,
        "per_device_eval_batch_size": batch_size,
    }
    if num_layers == 6:
        default_training_args.update(
            {
                "learning_rate": 5e-5,
                "lr_scheduler_type": "linear",
                "warmup_steps": 500,
            }
        )

    training_args = {
        "num_train_epochs": epochs,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": evaluation_strategy,
        # Roughly 8 logging events per epoch.
        "logging_steps": np.floor(len(data) / batch_size / 8),
        "save_strategy": "epoch",
        "group_by_length": False,
        "length_column_name": "length",
        "disable_tqdm": False,
        "weight_decay": 0.001,
        "load_best_model_at_end": load_best_model_at_end,
    }
    training_args.update(default_training_args)

    return training_args, freeze_layers_get
|
|
|
|
|
def quant_layers(model):
    """Infer the encoder layer count from parameter names containing "layer"."""
    layer_indices = [
        int(name.split("layer.")[1].split(".")[0])
        for name, _ in model.named_parameters()
        if "layer" in name
    ]
    return int(max(layer_indices)) + 1
|
|
|
|
|
def compute_metrics(pred):
    """Compute accuracy, macro F1 and weighted F1 from a Trainer EvalPrediction."""
    y_true = pred.label_ids
    y_hat = pred.predictions.argmax(-1)

    return {
        'accuracy': accuracy_score(y_true, y_hat),
        'macro_f1': f1_score(y_true, y_hat, average='macro'),
        'weighted_f1': f1_score(y_true, y_hat, average='weighted'),
    }
|
|
def evaluate_model(
    model,
    num_classes,
    id_class_dict,
    eval_data,
    predict=False,
    output_directory=None,
    output_prefix=None,
    classifier="cell",
    forward_batch_size=100,
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model : nn.Module
        | Loaded fine-tuned model (e.g. trainer.model)
    num_classes : int
        | Number of classes for classifier
    id_class_dict : dict
        | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    eval_data : Dataset
        | Loaded evaluation .dataset input
    predict : bool
        | Whether or not to save eval predictions
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    classifier : str
        | "cell" or "gene"; selects the label column used during prediction
        | (new parameter; the original referenced an undefined global)
    forward_batch_size : int
        | Batch size for the forward pass (new parameter; was hard-coded 100)
    """
    labels = id_class_dict.keys()
    # Bug fix: `classifier` was an undefined module global; it is now an
    # explicit parameter, and the batch size is configurable.
    y_pred, y_true, logits_list = classifier_predict(
        model, classifier, eval_data, forward_batch_size
    )
    conf_mat, macro_f1, acc, roc_metrics = get_metrics(
        y_pred, y_true, logits_list, num_classes, labels
    )
    if predict is True:
        pred_dict = {
            "pred_ids": y_pred,
            "label_ids": y_true,
            "predictions": logits_list,
        }
        pred_dict_output_path = (
            Path(output_directory) / f"{output_prefix}_pred_dict"
        ).with_suffix(".pkl")
        with open(pred_dict_output_path, "wb") as f:
            pickle.dump(pred_dict, f)
    return {
        "conf_mat": conf_mat,
        "macro_f1": macro_f1,
        "acc": acc,
        "roc_metrics": roc_metrics,
    }
|
|
|
|
|
def classifier_predict(model, classifier_type, evalset, forward_batch_size):
    """Run batched inference and return (y_pred, y_true, logits_list).

    Pads each batch to the longest sequence in the (possibly truncated)
    eval set, runs the model without gradients, and pairs per-position
    logits with labels, dropping positions labeled -100.

    NOTE(review): inputs are moved to "cuda" unconditionally — fails on
    CPU-only machines; confirm a GPU is always available here.
    """
    if classifier_type == "gene":
        label_name = "labels"
    elif classifier_type == "cell":
        label_name = "label"

    predict_logits = []
    predict_labels = []
    model.eval()

    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    # Drop a single trailing example: torch.squeeze on a batch of size 1
    # would collapse the batch dimension and break the concatenation below.
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    # Pad every batch to the longest sequence in the evaluated subset.
    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])

    # Silence the datasets progress bar during the per-batch map calls.
    disable_progress_bar()
    for i in trange(0, evalset_len, forward_batch_size):
        max_range = min(i + forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(
            batch_evalset, max_evalset_len, label_name
        )
        padded_batch.set_format(type="torch")

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch[label_name]
        with torch.no_grad():
            outputs = model(
                input_ids=input_data_batch.to("cuda"),
                attention_mask=attn_msk_batch.to("cuda"),
                labels=label_batch.to("cuda"),
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    enable_progress_bar()
    # Flatten batches to one (positions, classes) matrix and matching labels.
    logits_by_cell = torch.cat(predict_logits)
    last_dim = len(logits_by_cell.shape) - 1
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    # Keep only positions with real labels (padding uses -100).
    logit_label_paired = [
        item
        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
        if item[1] != -100
    ]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    return y_pred, y_true, logits_list
|
|
|
|
|
def find_largest_div(N, K):
    """Return the largest multiple of K that is less than or equal to N."""
    return N - (N % K)
|
|
def preprocess_classifier_batch(cell_batch, max_len, label_name):
    """Pad input_ids and labels to max_len and build the attention mask.

    Labels are padded with -100 (ignored positions), input_ids with the
    <pad> token, and attention_mask is 1 wherever input_ids is not <pad>.

    NOTE(review): `gene_token_dict` is a module global (presumably the
    Geneformer token dictionary) — not defined in this file; verify. For
    cell classifiers the scalar label is also passed through np.pad —
    assumes that is harmless for scalar labels; TODO confirm.
    """
    if max_len is None:
        # Fall back to the longest sequence in this batch.
        max_len = max([len(i) for i in cell_batch["input_ids"]])

    def pad_label_example(example):
        # Pad labels with -100 so padded positions are ignored downstream.
        example[label_name] = np.pad(
            example[label_name],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=-100,
        )
        example["input_ids"] = np.pad(
            example["input_ids"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=gene_token_dict.get("<pad>"),
        )
        # Attend only to real (non-pad) tokens.
        example["attention_mask"] = (
            example["input_ids"] != gene_token_dict.get("<pad>")
        ).astype(int)
        return example

    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch
|
|
def vote(logit_list):
    """Return the argmax index of logit_list, or "tie" if the max is not unique.

    Removed a dead `logit_list.index(m)` call whose result was discarded.
    """
    top = max(logit_list)
    winners = [i for i, value in enumerate(logit_list) if value == top]
    if len(winners) > 1:
        return "tie"
    return winners[0]
|
|
def py_softmax(vector):
    """Numerically stable softmax of a 1-D sequence; returns an ndarray.

    Subtracting the max before exponentiating prevents overflow for large
    logits (the original `np.exp(vector)` returned inf/nan in that case);
    the result is mathematically identical.
    """
    v = np.asarray(vector, dtype=float)
    e = np.exp(v - v.max())
    return e / e.sum()
|
|
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
    """Compute confusion matrix, macro F1, accuracy, and binary-only ROC stats.

    Returns (conf_mat, macro_f1, acc, roc_metrics); roc_metrics is None
    unless num_classes == 2.
    """
    conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    acc = accuracy_score(y_true, y_pred)

    roc_metrics = None
    if num_classes == 2:
        # Positive-class probability via softmax over the two logits.
        y_score = [py_softmax(pair)[1] for pair in logits_list]
        fpr, tpr, _ = roc_curve(y_true, y_score)
        # Interpolate the TPR onto a fixed FPR grid for cross-run averaging.
        mean_fpr = np.linspace(0, 1, 100)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        roc_metrics = {
            "fpr": fpr,
            "tpr": tpr,
            "interp_tpr": interp_tpr,
            "auc": auc(fpr, tpr),
            "tpr_wt": len(tpr),
        }
    return conf_mat, macro_f1, acc, roc_metrics
|
|
def evaluate_saved_model(
    model_directory,
    id_class_dict_file,
    test_data_file,
    output_directory,
    output_prefix,
    predict=True,
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    test_data_file : Path
        | Path to directory containing test .dataset
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    predict : bool
        | Whether or not to save eval predictions

    **Returns**

    dict with keys "conf_matrix" (labeled DataFrame), "macro_f1", "acc",
    and "all_roc_metrics" (None unless binary classification).
    """
    # id -> class-label mapping written by prepare_data.
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    num_classes = get_num_classes(id_class_dict)

    # NOTE(review): relies on module-level globals `nproc` and `quantize`
    # defined elsewhere in this script — confirm they are set before calling.
    test_data = load_and_filter(None, nproc, test_data_file)

    model = load_model(
        "CellClassifier",
        num_classes,
        model_directory,
        "eval",
        quantize=quantize,
    )

    result = evaluate_model(
        model,
        num_classes,
        id_class_dict,
        test_data,
        predict=predict,
        output_directory=output_directory,
        # Fix: was hard-coded to "CellClassifier", ignoring the caller's
        # prefix so prediction outputs did not match the metrics files.
        output_prefix=output_prefix,
    )

    # Label confusion-matrix rows/columns with the class names.
    all_conf_mat_df = pd.DataFrame(
        result["conf_mat"],
        columns=id_class_dict.values(),
        index=id_class_dict.values(),
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": result["macro_f1"],
        "acc": result["acc"],
    }

    # ROC metrics only apply to binary classification.
    all_roc_metrics = None
    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        mean_tpr = result["roc_metrics"]["interp_tpr"]
        all_roc_auc = result["roc_metrics"]["auc"]
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics

    # Persist the metrics dict next to the other eval outputs.
    test_metrics_output_path = (
        Path(output_directory) / f"{output_prefix}_test_metrics_dict"
    ).with_suffix(".pkl")
    with open(test_metrics_output_path, "wb") as f:
        pickle.dump(all_metrics, f)

    return all_metrics
|
|
|
|
|
def plot_conf_mat(
    conf_mat_dict,
    output_directory,
    output_prefix,
    custom_class_order=None,
):
    """
    Plot confusion matrix results of evaluating the fine-tuned model.

    **Parameters**

    conf_mat_dict : dict
        | Dictionary of model_name : confusion_matrix_DataFrame
        | (all_metrics["conf_matrix"] from self.validate)
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    """
    # One plot per model, titled with the model's name.
    for model_name, conf_mat_df in conf_mat_dict.items():
        plot_confusion_matrix(
            conf_mat_df,
            model_name,
            output_directory,
            output_prefix,
            custom_class_order,
        )
|
|
def plot_confusion_matrix(
    conf_mat_df, title, output_dir, output_prefix, custom_class_order
):
    """Render a row-normalized confusion matrix and save it as a PDF."""
    figure = plt.figure()
    figure.set_size_inches(10, 10)
    sns.set(font_scale=1)
    sns.set_style("whitegrid", {"axes.grid": False})

    # Optionally impose a fixed class ordering on both axes.
    if custom_class_order is not None:
        conf_mat_df = conf_mat_df.reindex(
            index=custom_class_order, columns=custom_class_order
        )

    tick_labels = generate_display_labels(conf_mat_df)
    # L1-normalize each row so cells read as per-class fractions.
    normalized = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
    display = ConfusionMatrixDisplay(
        confusion_matrix=normalized, display_labels=tick_labels
    )
    display.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.show()

    out_path = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
    display.figure_.savefig(out_path, bbox_inches="tight")
|
|
def generate_display_labels(conf_mat_df):
    """Build "<label>\\nn=<row total>" tick labels for a confusion-matrix plot.

    Idiom fix: replaces a manually maintained counter (`i = 0; ...; i = i + 1`)
    with `enumerate` in a comprehension.

    Parameters
    ----------
    conf_mat_df : pandas.DataFrame
        Confusion matrix whose index holds the class labels; each row's sum
        is that class's total sample count.

    Returns
    -------
    list of str
        One label per row, annotated with the row total.
    """
    return [
        f"{label}\nn={conf_mat_df.iloc[i, :].sum():.0f}"
        for i, label in enumerate(conf_mat_df.index)
    ]
|
|
|
|
|
def plot_predictions(
    predictions_file,
    id_class_dict_file,
    title,
    output_directory,
    output_prefix,
    custom_class_order=None,
    kwargs_dict=None,
):
    """
    Plot prediction results of evaluating the fine-tuned model.

    **Parameters**

    predictions_file : path
        | Path of model predictions output to plot
        | (saved output from self.validate if predict_eval=True)
        | (or saved output from self.evaluate_saved_model)
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    title : str
        | Title for legend containing class labels.
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    kwargs_dict : None, dict
        | Dictionary of kwargs to pass to plotting function.
    """
    # Load pickled predictions; may be a plain dict or a
    # PredictionOutput-style object (handled by the branch below).
    with open(predictions_file, "rb") as f:
        predictions = pickle.load(f)

    # Numerical-ID -> class-label mapping.
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    if isinstance(predictions, dict):
        if all(
            [
                key in predictions.keys()
                for key in ["pred_ids", "label_ids", "predictions"]
            ]
        ):
            # Dict format: logits and true label ids stored under keys.
            # NOTE(review): a dict lacking these keys leaves
            # predictions_logits/true_ids unset — confirm inputs upstream.
            predictions_logits = np.array(predictions["predictions"])
            true_ids = predictions["label_ids"]
    else:
        # Object format: read logits / label ids as attributes.
        predictions_logits = predictions.predictions
        true_ids = predictions.label_ids

    num_classes = len(id_class_dict.keys())
    num_predict_classes = predictions_logits.shape[1]
    # NOTE(review): assert is stripped under `python -O`; a raise would be
    # safer for input validation.
    assert num_classes == num_predict_classes
    classes = id_class_dict.values()
    true_labels = [id_class_dict[idx] for idx in true_ids]
    predictions_df = pd.DataFrame(predictions_logits, columns=classes)
    if custom_class_order is not None:
        predictions_df = predictions_df.reindex(columns=custom_class_order)
    predictions_df["true"] = true_labels
    # Rank used to sort rows by true class (default: id_class_dict order).
    custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
    if custom_class_order is not None:
        custom_dict = dict(
            zip(custom_class_order, [i for i in range(len(custom_class_order))])
        )
    # Group rows of the same true class together for the heatmap.
    predictions_df = predictions_df.sort_values(
        by=["true"], key=lambda x: x.map(custom_dict)
    )

    plot_predictions_eu(
        predictions_df, title, output_directory, output_prefix, kwargs_dict
    )
|
|
def plot_predictions_eu(predictions_df, title, output_dir, output_prefix, kwargs_dict):
    """Draw a class-color-annotated clustermap of prediction logits; save a PDF."""
    sns.set(font_scale=2)
    plt.figure(figsize=(10, 10), dpi=150)

    # Per-row colors keyed on the true class, plus the class -> color lookup.
    label_colors, label_color_dict = make_colorbar(predictions_df, "true")
    predictions_df = predictions_df.drop(columns=["true"])

    # Column color strip: one swatch per predicted-class column.
    column_names = [name for name in predictions_df.columns]
    column_swatches = [label_color_dict[name] for name in column_names]
    predict_colors = pd.DataFrame(
        pd.Series(column_swatches, index=column_names), columns=["predicted"]
    )

    plot_kwargs = {
        "row_cluster": False,
        "col_cluster": False,
        "row_colors": label_colors,
        "col_colors": predict_colors,
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "center": 0,
        "cmap": "vlag",
    }
    # Caller-supplied kwargs override the defaults above.
    if kwargs_dict is not None:
        plot_kwargs.update(kwargs_dict)
    g = sns.clustermap(predictions_df, **plot_kwargs)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # Zero-size bars exist only to provide legend handles per class color.
    for class_name in list(label_color_dict.keys()):
        g.ax_col_dendrogram.bar(
            0, 0, color=label_color_dict[class_name], label=class_name, linewidth=0
        )

    g.ax_col_dendrogram.legend(
        title=f"{title}",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, 1),
        facecolor="white",
    )

    output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
    plt.savefig(output_file, bbox_inches="tight")
|
|
def make_colorbar(embs_df, label):
    """Return (row-color DataFrame, class->color dict) for heatmap annotation."""
    class_values = list(embs_df[label])

    # One color per row, aligned to the frame's index.
    row_colors = gen_heatmap_class_colors(class_values, embs_df)
    label_colors = pd.DataFrame(row_colors, columns=[label])

    # Lookup from class label to its assigned color.
    label_color_dict = gen_heatmap_class_dict(class_values, label_colors[label])
    return label_colors, label_color_dict
|
|
def gen_heatmap_class_colors(labels, df):
    """Assign each distinct label a cubehelix color; return a Series indexed like df."""
    distinct_count = len(Counter(labels).keys())
    palette = sns.cubehelix_palette(
        distinct_count,
        light=0.9,
        dark=0.1,
        hue=1,
        reverse=True,
        start=1,
        rot=-2,
    )
    # Counter preserves first-seen order, so palette colors are assigned
    # in the order labels first appear.
    lut = dict(zip(map(str, Counter(labels).keys()), palette))
    return pd.Series(labels, index=df.index).map(lut)
|
|
def gen_heatmap_class_dict(classes, label_colors_series):
    """Map each class to its color, keeping the first color seen per class."""
    color_map = {}
    # Pair classes with colors positionally; only the first occurrence of a
    # class records its color (same as drop_duplicates keeping the first row).
    for cls, color in zip(classes, label_colors_series):
        if cls not in color_map:
            color_map[cls] = color
    return color_map
|
|
|
|
|
|
|
|
# Main experiment loop: fine-tune and evaluate the disease classifier
# `iter_step` times, accumulating test macro-F1 / accuracy per iteration.
# NOTE(review): `iter_step`, `acc_list`, and helpers (prepare_data,
# get_num_classes, load_and_filter, filter_by_dict, train_classifier,
# evaluate_model, eu) are not defined in this chunk — confirm they are
# defined earlier in the script before running this section.
for i in range(iter_step):

    model_directory = "model path"  # TODO: placeholder — set to the real pretrained-model path

    # Load the gene <-> token-id dictionaries used for padding/masking.
    corpus_dir = "Pretrain_data"
    with open(corpus_dir + "/token_dictionary.pkl", "rb") as fp:
        gene_token_dict = pickle.load(fp)
    token_gene_dict = {v: k for k, v in gene_token_dict.items()}

    # Restrict to cardiomyocyte cell types for the disease classifier.
    filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}
    # Fine-tuning hyperparameters (presumably from a prior tuning run — verify).
    training_args = {
        "num_train_epochs": 0.9,
        "learning_rate": 0.000804,
        "lr_scheduler_type": "polynomial",
        "warmup_steps": 1812,
        "weight_decay":0.258828,
        "per_device_train_batch_size": 12,
        "seed": 73,
    }

    cell_state_dict = {"state_key": "disease", "states": "all"}
    classifier='cell'
    filter_data=filter_data_dict
    # 80/10/10 train/valid/test proportions (actual split is by individual below).
    split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1}
    train_size = split_sizes["train"]
    valid_size = split_sizes["valid"]
    oos_test_size = split_sizes["test"]
    max_ncells=None
    freeze_layers = 2            # freeze the first 2 transformer layers
    num_crossval_splits = 1
    forward_batch_size=200
    nproc=16
    rare_threshold=0
    quantize=None

    # Patient IDs for leakage-free, individual-level splits.
    train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678", "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504", "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630", "1290", "1479", "1371", "1549", "1515"]
    eval_ids = ["1422", "1510", "1539", "1606", "1702"]
    test_ids = ["1437", "1516", "1602", "1685", "1718"]

    # train+eval individuals vs. held-out test individuals.
    train_test_id_split_dict = {"attr_key": "individual",
                                "train": train_ids+eval_ids,
                                "test": test_ids}
    # Within the training pool: train vs. validation individuals.
    train_valid_id_split_dict = {"attr_key": "individual",
                                 "train": train_ids,
                                 "eval": eval_ids}

    # Timestamped output directory for this iteration.
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.strftime('%X').replace(':','')}"
    datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_directory = "output path"  # TODO: placeholder — set to the real output path

    if output_directory[-1:] != "/":
        output_directory = output_directory + "/"
    output_dir = f"{output_directory}{datestamp}_geneformer_diseaseClassifier/"
    output_prefix = "cm_classifier_test"
    # NOTE(review): the shell `mkdir` is redundant with os.makedirs below.
    subprocess.call(f"mkdir {output_dir}", shell=True)
    os.makedirs(output_dir, exist_ok=True)

    # Label the dataset and write train/test splits by individual.
    prepare_data(input_data_file="example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset",
                 output_directory=output_dir,
                 output_prefix=output_prefix,
                 split_id_dict=train_test_id_split_dict)

    # id -> class-label mapping produced by prepare_data (and its inverse).
    with open(f"{output_dir}/{output_prefix}_id_class_dict.pkl", "rb") as f:
        id_class_dict = pickle.load(f)
    class_id_dict = {v: k for k, v in id_class_dict.items()}

    num_classes = get_num_classes(id_class_dict)

    data = load_and_filter(None, nproc, f"{output_dir}/{output_prefix}_labeled_train.dataset")
    data = data.shuffle(seed=42)  # deterministic shuffle for reproducibility

    # Cross-validation bookkeeping.
    results = []
    all_conf_mat = np.zeros((num_classes, num_classes))
    iteration_num = 1
    split_id_dict=train_valid_id_split_dict

    for i in trange(num_crossval_splits):
        print(
            f"****** Validation split: {iteration_num}/{num_crossval_splits} ******\n"
        )
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        if num_crossval_splits == 1:
            # Single split: use the explicit individual-level train/eval IDs.
            if split_id_dict is not None:
                data_dict = dict()
                data_dict["train"] = filter_by_dict(
                    data,
                    {split_id_dict["attr_key"]: split_id_dict["train"]},
                    nproc,
                )
                data_dict["test"] = filter_by_dict(
                    data,
                    {split_id_dict["attr_key"]: split_id_dict["eval"]},
                    nproc,
                )
                train_data = data_dict["train"]
                eval_data = data_dict["test"]

        # Fine-tune on this split's training data.
        trainer = train_classifier(
            model_directory,
            num_classes,
            train_data,
            eval_data,
            ksplit_output_dir,
        )

        # Evaluate the fine-tuned model on the split's eval data.
        result = evaluate_model(
            trainer.model,
            num_classes,
            id_class_dict,
            eval_data,
            True,
            ksplit_output_dir,
            output_prefix,
        )
        results += [result]
        all_conf_mat = all_conf_mat + result["conf_mat"]
        iteration_num = iteration_num + 1

    # Aggregate metrics across splits.
    all_conf_mat_df = pd.DataFrame(
        all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": [result["macro_f1"] for result in results],
        "acc": [result["acc"] for result in results],
    }
    all_roc_metrics = None
    # ROC averaging only applies to binary classification.
    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
        all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
        all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
        # NOTE(review): `eu` (evaluation utils?) is not imported in this
        # chunk — confirm it is available at runtime.
        mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
            all_tpr, all_roc_auc, all_tpr_wt
        )
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
            "roc_auc": roc_auc,
            "roc_auc_sd": roc_auc_sd,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics
    save_eval_output=True
    if save_eval_output is True:
        # Persist the cross-validation metrics for this iteration.
        eval_metrics_output_path = (
            Path(output_dir) / f"cm_classifier_test_eval_metrics_dict"
        ).with_suffix(".pkl")
        with open(eval_metrics_output_path, "wb") as f:
            pickle.dump(all_metrics, f)

    datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
    # Final evaluation of the split-1 model on the held-out test individuals.
    all_metrics_test = evaluate_saved_model(
        model_directory=f"{output_dir}/ksplit1/",
        id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
        test_data_file=f"{output_dir}/{output_prefix}_labeled_test.dataset",
        output_directory=output_dir,
        output_prefix=output_prefix,
    )

    # Accumulate per-iteration test metrics for the summary below.
    macro_f1_list.append(all_metrics_test['macro_f1'])
    acc_list.append(all_metrics_test['acc'])

# Summary across all iterations of the outer loop.
print("Macro F1: ", macro_f1_list)
print("Accuracy: ", acc_list)
|
|
|