import os
import numpy as np
from tqdm.auto import tqdm, trange

GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"

# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification, BertForMaskedLM, BertForTokenClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
import torch
import pandas as pd
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
from sklearn import preprocessing
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,
    roc_curve,
)
from pathlib import Path
import matplotlib.pyplot as plt

import sys
# sys.path.append('geneformer')
from geneformer import DataCollatorForCellClassification

macro_f1_list = []
acc_list = []

iter_step = 2


def prepare_data(
    input_data_file,
    output_directory,
    output_prefix,
    split_id_dict=None,
    test_size=None,
    attr_to_split=None,
    attr_to_balance=None,
    max_trials=100,
    pval_threshold=0.1,
):
    """
    Prepare data for cell state or gene classification.

    **Parameters**

    input_data_file : Path
        | Path to directory containing .dataset input
    output_directory : Path
        | Path to directory where prepared data will be saved
    output_prefix : str
        | Prefix for output file
    split_id_dict : None, dict
        | Dictionary of IDs for train and test splits
        | Three-item dictionary with keys: attr_key, train, test
        | attr_key: name of .dataset column holding the split IDs
        | train / test: lists of IDs in that column for each split
        | e.g. {"attr_key": "individual",
        |       "train": ["patient1", "patient2", "patient3", "patient4"],
        |       "test": ["patient5", "patient6"]}
    test_size : None, float
        | Proportion of data to hold out for the test set (e.g. 0.2 for 20%)
        | If None, falls back to the module-level default (oos_test_size)
        | Note: only available for CellClassifiers
    attr_to_split : None, str
        | Attribute key on which to split data while balancing confounders
        | e.g. "patient_id" to split by patient
    attr_to_balance : None, list
        | Attribute keys to balance while splitting on attr_to_split
        | e.g. ["age", "sex"]
    max_trials : None, int
        | Max number of random-split trials to achieve balanced attributes;
        | if none succeed, the best trial is used
    pval_threshold : None, float
        | A trial is accepted if p >= pval_threshold for all attr_to_balance
    """
    if test_size is None:
        # NOTE(review): oos_test_size is a free variable — confirm it is defined
        # elsewhere in the notebook before this function is called.
        test_size = oos_test_size

    # prepare data and labels for classification
    # NOTE(review): filter_data, nproc, classifier, logger, rare_threshold,
    # cell_state_dict, max_ncells and cu are also free variables here.
    data = load_and_filter(filter_data, nproc, input_data_file)

    if classifier == "cell":
        if "label" in data.features:
            logger.error(
                "Column name 'label' must be reserved for class IDs. Please rename column."
            )
            # BUG FIX: a bare `raise` with no active exception raises
            # RuntimeError("No active exception to re-raise"), hiding the cause.
            raise ValueError(
                "Column name 'label' must be reserved for class IDs. Please rename column."
            )
    elif classifier == "gene":
        if "labels" in data.features:
            logger.error(
                "Column name 'labels' must be reserved for class IDs. Please rename column."
            )
            raise ValueError(
                "Column name 'labels' must be reserved for class IDs. Please rename column."
            )

    if (attr_to_split is not None) and (attr_to_balance is None):
        logger.error(
            "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
        )
        raise ValueError(
            "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
        )

    if not isinstance(attr_to_balance, list):
        attr_to_balance = [attr_to_balance]

    if classifier == "cell":
        # remove cell states representing < rare_threshold of cells
        data = remove_rare(data, rare_threshold, cell_state_dict["state_key"], nproc)
        # downsample max cells and max per class
        data = downsample_and_shuffle(data, max_ncells, None, cell_state_dict)
        # rename cell state column to "label"
        data = rename_cols(data, cell_state_dict["state_key"])

    # convert classes to numerical labels and save as id_class_dict;
    # when (cross-)validating, genes are relabeled per split at training time
    data, id_class_dict = label_classes(classifier, data, None, nproc)

    # save id_class_dict for future reference
    id_class_output_path = (
        Path(output_directory) / f"{output_prefix}_id_class_dict"
    ).with_suffix(".pkl")
    with open(id_class_output_path, "wb") as f:
        pickle.dump(id_class_dict, f)

    if split_id_dict is not None:
        # explicit ID-based split
        data_dict = dict()
        data_dict["train"] = filter_by_dict(
            data, {split_id_dict["attr_key"]: split_id_dict["train"]}, nproc
        )
        data_dict["test"] = filter_by_dict(
            data, {split_id_dict["attr_key"]: split_id_dict["test"]}, nproc
        )
        train_data_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled_train"
        ).with_suffix(".dataset")
        test_data_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled_test"
        ).with_suffix(".dataset")
        data_dict["train"].save_to_disk(str(train_data_output_path))
        data_dict["test"].save_to_disk(str(test_data_output_path))
    elif (test_size is not None) and (classifier == "cell"):
        if 1 > test_size > 0:
            if attr_to_split is None:
                # plain random split
                data_dict = data.train_test_split(
                    test_size=test_size,
                    stratify_by_column=None,
                    seed=42,
                )
                train_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_train"
                ).with_suffix(".dataset")
                test_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_test"
                ).with_suffix(".dataset")
                data_dict["train"].save_to_disk(str(train_data_output_path))
                data_dict["test"].save_to_disk(str(test_data_output_path))
            else:
                # split on attr_to_split while balancing attr_to_balance
                data_dict, balance_df = cu.balance_attr_splits(
                    data,
                    attr_to_split,
                    attr_to_balance,
                    test_size,
                    max_trials,
                    pval_threshold,
                    cell_state_dict["state_key"],
                    nproc,
                )
                balance_df.to_csv(
                    f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
                )
                train_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_train"
                ).with_suffix(".dataset")
                test_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_test"
                ).with_suffix(".dataset")
                data_dict["train"].save_to_disk(str(train_data_output_path))
                data_dict["test"].save_to_disk(str(test_data_output_path))
        else:
            # test_size outside (0, 1): save the whole labeled dataset unsplit
            data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled"
            ).with_suffix(".dataset")
            data.save_to_disk(str(data_output_path))
            print(data_output_path)
    else:
        data_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled"
        ).with_suffix(".dataset")
        data.save_to_disk(str(data_output_path))
def load_and_filter(filter_data, nproc, input_data_file):
    """Load a saved .dataset from disk and optionally filter it.

    filter_data : None or dict of column -> allowed values (see filter_by_dict).
    nproc : number of processes for dataset filtering.
    """
    data = load_from_disk(input_data_file)
    if filter_data is not None:
        data = filter_by_dict(data, filter_data, nproc)
    return data


def get_num_classes(id_class_dict):
    """Return the number of distinct class labels in an id -> class dict."""
    return len(set(id_class_dict.values()))


def filter_by_dict(data, filter_data, nproc):
    """Keep only examples whose value for each key is in the allowed list.

    filter_data : dict mapping column name -> collection of allowed values.
    Raises ValueError if nothing survives the filter.
    """
    for key, value in filter_data.items():

        def filter_data_by_criteria(example):
            return example[key] in value

        data = data.filter(filter_data_by_criteria, num_proc=nproc)
    if len(data) == 0:
        logger.error("No cells remain after filtering. Check filtering criteria.")
        # BUG FIX: bare `raise` with no active exception raises an unrelated
        # RuntimeError; raise an explicit error instead.
        raise ValueError("No cells remain after filtering. Check filtering criteria.")
    return data


def remove_rare(data, rare_threshold, label, nproc):
    """Drop classes representing <= rare_threshold fraction of all cells.

    label : name of the column holding the class labels.
    """
    if rare_threshold > 0:
        total_cells = len(data)
        label_counter = Counter(data[label])
        # BUG FIX: iterating a Counter yields keys only; iterating (k, v)
        # pairs requires .items().
        nonrare_label_dict = {
            label: [
                k
                for k, v in label_counter.items()
                if (v / total_cells) > rare_threshold
            ]
        }
        data = filter_by_dict(data, nonrare_label_dict, nproc)
    return data


def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
    """Shuffle (seed=42) and cap the dataset at max_ncells / max_ncells_per_class."""
    data = data.shuffle(seed=42)
    num_cells = len(data)
    # if max number of cells is defined, then subsample to this max number
    if max_ncells is not None:
        if num_cells > max_ncells:
            data = data.select([i for i in range(max_ncells)])
    if max_ncells_per_class is not None:
        # NOTE(review): `random` and `subsample_by_class` are not defined in
        # this notebook; this branch only runs when max_ncells_per_class is
        # not None (callers here always pass None) — confirm before enabling.
        class_labels = data[cell_state_dict["state_key"]]
        random.seed(42)
        subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
        data = data.select(subsample_indices)
    return data


def rename_cols(data, state_key):
    """Rename the cell-state column to the reserved name "label"."""
    data = data.rename_column(state_key, "label")
    return data


def label_classes(classifier, data, gene_class_dict, nproc):
    """Convert class labels to numerical IDs; return (data, id -> class dict).

    For "cell", maps the "label" column through a class -> id dict.
    For "gene", keeps only cells containing target genes and writes per-token
    labels into a "labels" column.
    """
    if classifier == "cell":
        label_set = set(data["label"])
    elif classifier == "gene":
        # remove cells without any of the target genes
        def if_contains_label(example):
            # NOTE(review): `pu.flatten_list` and `label_gene_classes` are not
            # defined in this notebook — gene path appears incomplete here.
            a = pu.flatten_list(gene_class_dict.values())
            b = example["input_ids"]
            return not set(a).isdisjoint(b)

        data = data.filter(if_contains_label, num_proc=nproc)
        label_set = gene_class_dict.keys()

        if len(data) == 0:
            logger.error(
                "No cells remain after filtering for target genes. Check target gene list."
            )
            # BUG FIX: bare `raise` replaced with an explicit error.
            raise ValueError(
                "No cells remain after filtering for target genes. Check target gene list."
            )

    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
    id_class_dict = {v: k for k, v in class_id_dict.items()}

    def classes_to_ids(example):
        if classifier == "cell":
            example["label"] = class_id_dict[example["label"]]
        elif classifier == "gene":
            example["labels"] = label_gene_classes(
                example, class_id_dict, gene_class_dict
            )
        return example

    data = data.map(classes_to_ids, num_proc=nproc)
    return data, id_class_dict
def train_classifier(
    model_directory,
    num_classes,
    train_data,
    eval_data,
    output_directory,
    predict=False,
    classifier='cell',
    no_eval=False,
    quantize=False,
    freeze_layers=2,
    training_args=None,
):
    """
    Fine-tune model for cell state or gene classification.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    num_classes : int
        | Number of classes for classifier
    train_data : Dataset
        | Loaded training .dataset input
        | Cell classifier labels in column "label"; gene classifier in "labels".
    eval_data : None, Dataset
        | (Optional) Loaded evaluation .dataset input
    output_directory : Path
        | Path to directory where fine-tuned model will be saved
    predict : bool
        | Whether or not to save eval predictions from trainer
    classifier : str
        | "cell" or "gene"
    no_eval : bool
        | If True, train without evaluation even if eval_data is given
    quantize : bool
        | Whether to load the model quantized
    freeze_layers : None, int
        | Number of initial encoder layers to freeze
    training_args : None, dict
        | Overrides merged into the default TrainingArguments dict
        | (BUG FIX: was referenced as an undefined free variable; now a
        | backward-compatible keyword parameter)
    """
    ##### Validate and prepare data #####
    train_data, eval_data = validate_and_clean_cols(train_data, eval_data, classifier)

    if (no_eval is True) and (eval_data is not None):
        logger.warning(
            "no_eval set to True; model will be trained without evaluation."
        )
        eval_data = None

    if (classifier == "gene") and (predict is True):
        logger.warning(
            "Predictions during training not currently available for gene classifiers; setting predict to False."
        )
        predict = False

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
    if os.path.isfile(saved_model_test) is True:
        logger.error("Model already saved to this designated output directory.")
        # BUG FIX: bare `raise` replaced with an explicit error.
        raise ValueError("Model already saved to this designated output directory.")
    # make output directory
    # BUG FIX: was os.makedirs(output_dir, ...) — `output_dir` is undefined here.
    os.makedirs(output_directory, exist_ok=True)

    ##### Load model and training args #####
    model = load_model(
        "CellClassifier",
        num_classes,
        model_directory,
        "train",
        quantize=quantize,
    )
    def_training_args, def_freeze_layers = get_default_train_args(
        model, classifier, train_data, output_directory
    )

    if training_args is not None:
        def_training_args.update(training_args)
    # ~10 logging events per epoch
    logging_steps = round(
        len(train_data) / def_training_args["per_device_train_batch_size"] / 10
    )
    def_training_args["logging_steps"] = logging_steps
    def_training_args["output_dir"] = output_directory
    if eval_data is None:
        def_training_args["evaluation_strategy"] = "no"
        def_training_args["load_best_model_at_end"] = False
    training_args_init = TrainingArguments(**def_training_args)

    if freeze_layers is not None:
        def_freeze_layers = freeze_layers

    if def_freeze_layers > 0:
        # freeze the first def_freeze_layers encoder layers
        modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    ##### Fine-tune the model #####
    # define the data collator
    if classifier == "cell":
        data_collator = DataCollatorForCellClassification()
    elif classifier == "gene":
        # BUG FIX: was `self.classifier` inside a plain function.
        # NOTE(review): DataCollatorForGeneClassification is not imported in
        # this notebook — the gene path will fail until it is.
        data_collator = DataCollatorForGeneClassification()

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=eval_data,
        compute_metrics=compute_metrics,
    )

    # train the classifier
    trainer.train()
    trainer.save_model(output_directory)
    if predict is True:
        # make eval predictions and save predictions and metrics
        predictions = trainer.predict(eval_data)
        prediction_output_path = f"{output_directory}/predictions.pkl"
        with open(prediction_output_path, "wb") as f:
            pickle.dump(predictions, f)
        trainer.save_metrics("eval", predictions.metrics)
    return trainer


def validate_and_clean_cols(train_data, eval_data, classifier):
    """Check that the expected label column exists and drop extraneous columns.

    Returns (train_data, eval_data) restricted to the label column plus
    "input_ids" and "length". Raises ValueError if the label column is missing.
    """
    if classifier == "cell":
        label_col = "label"
    elif classifier == "gene":
        label_col = "labels"

    cols_to_keep = [label_col] + ["input_ids", "length"]
    if label_col not in train_data.column_names:
        logger.error(f"train_data must contain column {label_col} with class labels.")
        # BUG FIX: bare `raise` replaced with an explicit error.
        raise ValueError(
            f"train_data must contain column {label_col} with class labels."
        )
    else:
        train_data = remove_cols(train_data, cols_to_keep)

    if eval_data is not None:
        if label_col not in eval_data.column_names:
            logger.error(
                f"eval_data must contain column {label_col} with class labels."
            )
            raise ValueError(
                f"eval_data must contain column {label_col} with class labels."
            )
        else:
            eval_data = remove_cols(eval_data, cols_to_keep)
    return train_data, eval_data


def remove_cols(data, cols_to_keep):
    """Drop every dataset column not listed in cols_to_keep."""
    other_cols = list(data.features.keys())
    other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
    data = data.remove_columns(other_cols)
    return data
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
    """Load a BERT model of the requested flavor, optionally quantized.

    model_type : one of "Pretrained", "GeneClassifier", "CellClassifier",
        "MTLCellClassifier", or "MTLCellClassifier-Quantized" (implies quantize).
    num_classes : number of output labels (ignored for "Pretrained").
    model_directory : path passed to from_pretrained.
    mode : "train" or "eval"; eval mode enables hidden states and model.eval().
    quantize : load with bitsandbytes quantization when True.
    """
    if model_type == "MTLCellClassifier-Quantized":
        model_type = "MTLCellClassifier"
        quantize = True

    output_hidden_states = mode == "eval"

    # Quantization configuration (8-bit for MTL, 4-bit + LoRA otherwise).
    # NOTE(review): BitsAndBytesConfig / LoraConfig / get_peft_model are not
    # imported in this notebook; the quantize path needs those imports.
    quantize_config = None
    peft_config = None
    if quantize:
        if model_type == "MTLCellClassifier":
            quantize_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            quantize_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            peft_config = LoraConfig(
                lora_alpha=128,
                lora_dropout=0.1,
                r=64,
                bias="none",
                task_type="TokenClassification",
            )

    # Map model type to its transformers class.
    model_classes = {
        "Pretrained": BertForMaskedLM,
        "GeneClassifier": BertForTokenClassification,
        "CellClassifier": BertForSequenceClassification,
        "MTLCellClassifier": BertForMaskedLM,
    }
    model_class = model_classes.get(model_type)
    if not model_class:
        raise ValueError(f"Unknown model type: {model_type}")

    # Assemble from_pretrained kwargs.
    model_args = {
        "pretrained_model_name_or_path": model_directory,
        "output_hidden_states": output_hidden_states,
        "output_attentions": False,
    }
    if model_type != "Pretrained":
        model_args["num_labels"] = num_classes
    if quantize_config:
        model_args["quantization_config"] = quantize_config

    model = model_class.from_pretrained(**model_args)

    if mode == "eval":
        model.eval()

    # Device placement / PEFT wiring.
    if not quantize:
        # Only move non-quantized models; bitsandbytes handles placement itself.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
    elif peft_config:
        # Apply PEFT for quantized models (except MTLCellClassifier).
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)

    return model


def get_default_train_args(model, classifier, data, output_dir):
    """Build the default TrainingArguments dict and default freeze count.

    Returns (training_args_dict, freeze_layers); freeze count defaults to 0
    and is overridden by the caller. A 6-layer model gets an explicit
    learning-rate schedule; other depths keep Trainer defaults.
    """
    num_layers = quant_layers(model)
    default_freeze = 0
    batch_size = 12
    if classifier == "cell":
        epochs = 10
        evaluation_strategy = "epoch"
        load_best_model_at_end = True
    else:
        epochs = 1
        evaluation_strategy = "no"
        load_best_model_at_end = False

    if num_layers == 6:
        size_specific_args = {
            "learning_rate": 5e-5,
            "lr_scheduler_type": "linear",
            "warmup_steps": 500,
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
        }
    else:
        size_specific_args = {
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
        }

    training_args = {
        "num_train_epochs": epochs,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": evaluation_strategy,
        "logging_steps": np.floor(len(data) / batch_size / 8),  # 8 evals per epoch
        "save_strategy": "epoch",
        "group_by_length": False,
        "length_column_name": "length",
        "disable_tqdm": False,
        "weight_decay": 0.001,
        "load_best_model_at_end": load_best_model_at_end,
    }
    training_args.update(size_specific_args)

    return training_args, default_freeze


def quant_layers(model):
    """Infer layer count as 1 + the highest index found after 'layer.' in
    parameter names."""
    indices = [
        int(name.split("layer.")[1].split(".")[0])
        for name, _ in model.named_parameters()
        if "layer" in name
    ]
    return int(max(indices)) + 1
def compute_metrics(pred):
    """Trainer-compatible metric hook.

    pred : object with .label_ids and .predictions (logits); returns accuracy,
    macro F1 and weighted F1 computed with sklearn.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    weighted_f1 = f1_score(labels, preds, average='weighted')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1
    }


def evaluate_model(
    model,
    num_classes,
    id_class_dict,
    eval_data,
    predict=False,
    output_directory=None,
    output_prefix=None,
    classifier="cell",
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model : nn.Module
        | Loaded fine-tuned model (e.g. trainer.model)
    num_classes : int
        | Number of classes for classifier
    id_class_dict : dict
        | Loaded _id_class_dict.pkl previously prepared by prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    eval_data : Dataset
        | Loaded evaluation .dataset input
    predict : bool
        | Whether or not to save eval predictions
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    classifier : str
        | "cell" or "gene"
        | (BUG FIX: `classifier` was an undefined free variable at the
        | classifier_predict call; now a parameter defaulting to "cell")
    """
    ##### Evaluate the model #####
    labels = id_class_dict.keys()
    y_pred, y_true, logits_list = classifier_predict(
        model, classifier, eval_data, 100
    )
    conf_mat, macro_f1, acc, roc_metrics = get_metrics(
        y_pred, y_true, logits_list, num_classes, labels
    )
    if predict is True:
        # persist raw predictions alongside the metrics
        pred_dict = {
            "pred_ids": y_pred,
            "label_ids": y_true,
            "predictions": logits_list,
        }
        pred_dict_output_path = (
            Path(output_directory) / f"{output_prefix}_pred_dict"
        ).with_suffix(".pkl")
        with open(pred_dict_output_path, "wb") as f:
            pickle.dump(pred_dict, f)
    return {
        "conf_mat": conf_mat,
        "macro_f1": macro_f1,
        "acc": acc,
        "roc_metrics": roc_metrics,
    }
def classifier_predict(model, classifier_type, evalset, forward_batch_size):
    """Run batched forward passes over evalset and collect predictions.

    Returns (y_pred, y_true, logits_list) with ignore-index (-100) positions
    dropped. classifier_type selects the label column: "gene" -> "labels",
    "cell" -> "label".
    """
    if classifier_type == "gene":
        label_name = "labels"
    elif classifier_type == "cell":
        label_name = "label"
    else:
        raise ValueError(f"Unknown classifier_type: {classifier_type}")

    predict_logits = []
    predict_labels = []
    model.eval()
    # BUG FIX: device was hardcoded to "cuda", crashing on CPU-only hosts;
    # use whatever device the model actually lives on.
    device = next(model.parameters()).device

    # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])

    disable_progress_bar()  # disable progress bar for preprocess_classifier_batch mapping
    for i in trange(0, evalset_len, forward_batch_size):
        max_range = min(i + forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(
            batch_evalset, max_evalset_len, label_name
        )
        padded_batch.set_format(type="torch")

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch[label_name]
        with torch.no_grad():
            outputs = model(
                input_ids=input_data_batch.to(device),
                attention_mask=attn_msk_batch.to(device),
                labels=label_batch.to(device),
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    enable_progress_bar()
    logits_by_cell = torch.cat(predict_logits)
    last_dim = len(logits_by_cell.shape) - 1
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    # drop padded / ignored positions (label == -100)
    logit_label_paired = [
        item
        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
        if item[1] != -100
    ]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    return y_pred, y_true, logits_list


def find_largest_div(N, K):
    """Largest multiple of K that is <= N."""
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem


def preprocess_classifier_batch(cell_batch, max_len, label_name):
    """Pad input_ids/labels of a batch to max_len and build attention masks.

    Labels are padded with -100 (ignore index); input_ids with the pad token.
    """
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])

    def pad_label_example(example):
        example[label_name] = np.pad(
            example[label_name],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=-100,
        )
        # BUG FIX(conf mid): the pad-token key was the empty string "",
        # almost certainly "<pad>" mangled during extraction — Geneformer's
        # token dictionary uses "<pad>". Confirm against the token dict.
        example["input_ids"] = np.pad(
            example["input_ids"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=gene_token_dict.get("<pad>"),
        )
        example["attention_mask"] = (
            example["input_ids"] != gene_token_dict.get("<pad>")
        ).astype(int)
        return example

    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch


def vote(logit_list):
    """Index of the max logit, or "tie" if the max is not unique."""
    m = max(logit_list)
    # BUG FIX: removed dead statement `logit_list.index(m)` whose result was
    # discarded.
    indices = [i for i, x in enumerate(logit_list) if x == m]
    if len(indices) > 1:
        return "tie"
    else:
        return indices[0]


def py_softmax(vector):
    """Numerically stable softmax over a 1-D array-like."""
    # BUG FIX: subtract the max before exponentiating to avoid overflow for
    # large logits; mathematically identical to exp(v)/sum(exp(v)).
    v = np.asarray(vector, dtype=float)
    e = np.exp(v - v.max())
    return e / e.sum()


def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
    """Confusion matrix, macro F1, accuracy, and (binary-only) ROC metrics."""
    conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    acc = accuracy_score(y_true, y_pred)
    roc_metrics = None  # roc metrics not reported for multiclass
    if num_classes == 2:
        y_score = [py_softmax(item)[1] for item in logits_list]
        fpr, tpr, _ = roc_curve(y_true, y_score)
        mean_fpr = np.linspace(0, 1, 100)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        tpr_wt = len(tpr)
        roc_auc = auc(fpr, tpr)
        roc_metrics = {
            "fpr": fpr,
            "tpr": tpr,
            "interp_tpr": interp_tpr,
            "auc": roc_auc,
            "tpr_wt": tpr_wt,
        }
    return conf_mat, macro_f1, acc, roc_metrics
def evaluate_saved_model(
    model_directory,
    id_class_dict_file,
    test_data_file,
    output_directory,
    output_prefix,
    predict=True,
    quantize=False,
):
    """
    Evaluate a previously fine-tuned model on a held-out test set.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    test_data_file : Path
        | Path to directory containing test .dataset
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    predict : bool
        | Whether or not to save eval predictions
    quantize : bool
        | Whether to load the model quantized
        | (BUG FIX: was an undefined free variable; now a parameter)
    """
    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    # get number of classes for classifier
    num_classes = get_num_classes(id_class_dict)

    # load previously filtered and prepared data
    # NOTE(review): nproc is a free variable — confirm it is defined in the
    # notebook before this call.
    test_data = load_and_filter(None, nproc, test_data_file)

    # load previously fine-tuned model
    model = load_model(
        "CellClassifier",
        num_classes,
        model_directory,
        "eval",
        quantize=quantize,
    )

    # evaluate the model
    # BUG FIX: the caller's output_prefix was ignored in favor of the
    # hardcoded string "CellClassifier"; pass it through instead.
    result = evaluate_model(
        model,
        num_classes,
        id_class_dict,
        test_data,
        predict=predict,
        output_directory=output_directory,
        output_prefix=output_prefix,
    )

    all_conf_mat_df = pd.DataFrame(
        result["conf_mat"],
        columns=id_class_dict.values(),
        index=id_class_dict.values(),
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": result["macro_f1"],
        "acc": result["acc"],
    }
    all_roc_metrics = None  # roc metrics not reported for multiclass

    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        mean_tpr = result["roc_metrics"]["interp_tpr"]
        all_roc_auc = result["roc_metrics"]["auc"]
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics
    test_metrics_output_path = (
        Path(output_directory) / f"{output_prefix}_test_metrics_dict"
    ).with_suffix(".pkl")
    with open(test_metrics_output_path, "wb") as f:
        pickle.dump(all_metrics, f)

    return all_metrics


def plot_conf_mat(
    conf_mat_dict,
    output_directory,
    output_prefix,
    custom_class_order=None,
):
    """
    Plot confusion matrix results of evaluating the fine-tuned model.

    **Parameters**

    conf_mat_dict : dict
        | Dictionary of model_name : confusion_matrix_DataFrame
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots; same order for all models.
    """
    for model_name in conf_mat_dict.keys():
        plot_confusion_matrix(
            conf_mat_dict[model_name],
            model_name,
            output_directory,
            output_prefix,
            custom_class_order,
        )
def plot_confusion_matrix(
    conf_mat_df, title, output_dir, output_prefix, custom_class_order
):
    """Render one row-normalized confusion matrix and save it as PDF."""
    fig = plt.figure()
    fig.set_size_inches(10, 10)
    sns.set(font_scale=1)
    sns.set_style("whitegrid", {"axes.grid": False})
    if custom_class_order is not None:
        conf_mat_df = conf_mat_df.reindex(
            index=custom_class_order, columns=custom_class_order
        )
    display_labels = generate_display_labels(conf_mat_df)
    # l1-normalize rows so each row shows per-class proportions
    conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
    display = ConfusionMatrixDisplay(
        confusion_matrix=conf_mat, display_labels=display_labels
    )
    display.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.show()

    output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
    display.figure_.savefig(output_file, bbox_inches="tight")


def generate_display_labels(conf_mat_df):
    """Append per-class counts ("label\\nn=<row sum>") to each index label."""
    display_labels = []
    i = 0
    for label in conf_mat_df.index:
        display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
        i = i + 1
    return display_labels


def plot_predictions(
    predictions_file,
    id_class_dict_file,
    title,
    output_directory,
    output_prefix,
    custom_class_order=None,
    kwargs_dict=None,
):
    """
    Plot prediction results of evaluating the fine-tuned model.

    **Parameters**

    predictions_file : path
        | Path of model predictions output to plot
        | (saved output from validate if predict_eval=True,
        | or saved output from evaluate_saved_model)
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    title : str
        | Title for legend containing class labels.
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots; same order for all models.
    kwargs_dict : None, dict
        | Dictionary of kwargs to pass to plotting function.
    """
    # load predictions
    with open(predictions_file, "rb") as f:
        predictions = pickle.load(f)

    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    # BUG FIX: the `else` branch was attached to the inner key-check `if`,
    # so a non-dict trainer.predict output (the case its own comment
    # describes) never set predictions_logits and raised NameError below.
    if isinstance(predictions, dict) and all(
        key in predictions.keys()
        for key in ["pred_ids", "label_ids", "predictions"]
    ):
        # format is output from evaluate_saved_model
        predictions_logits = np.array(predictions["predictions"])
        true_ids = predictions["label_ids"]
    else:
        # format is output from validate if predict_eval=True
        predictions_logits = predictions.predictions
        true_ids = predictions.label_ids

    num_classes = len(id_class_dict.keys())
    num_predict_classes = predictions_logits.shape[1]
    assert num_classes == num_predict_classes
    classes = id_class_dict.values()
    true_labels = [id_class_dict[idx] for idx in true_ids]
    predictions_df = pd.DataFrame(predictions_logits, columns=classes)
    if custom_class_order is not None:
        predictions_df = predictions_df.reindex(columns=custom_class_order)
    predictions_df["true"] = true_labels
    custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
    if custom_class_order is not None:
        custom_dict = dict(
            zip(custom_class_order, [i for i in range(len(custom_class_order))])
        )
    # sort rows by true class so the heatmap groups classes together
    predictions_df = predictions_df.sort_values(
        by=["true"], key=lambda x: x.map(custom_dict)
    )

    plot_predictions_eu(
        predictions_df, title, output_directory, output_prefix, kwargs_dict
    )
def plot_predictions_eu(predictions_df, title, output_dir, output_prefix, kwargs_dict):
    """
    Render the prediction logits (cells x classes) as a heatmap, with row
    colors marking each cell's true class and column colors marking the
    predicted class, then save it as <output_prefix>_pred.pdf in output_dir.
    """
    sns.set(font_scale=2)
    plt.figure(figsize=(10, 10), dpi=150)

    # row-side colorbar keyed on the "true" column, plus a class -> color lookup
    label_colors, label_color_dict = make_colorbar(predictions_df, "true")
    predictions_df = predictions_df.drop(columns=["true"])

    # column-side colorbar: one swatch per predicted class, in column order
    column_labels = [name for name in predictions_df.columns]
    column_swatches = pd.Series(
        [label_color_dict[name] for name in column_labels], index=column_labels
    )
    predict_colors = pd.DataFrame(column_swatches, columns=["predicted"])

    heatmap_kwargs = {
        "row_cluster": False,
        "col_cluster": False,
        "row_colors": label_colors,
        "col_colors": predict_colors,
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "center": 0,
        "cmap": "vlag",
    }
    # caller-supplied kwargs override the defaults above
    if kwargs_dict is not None:
        heatmap_kwargs.update(kwargs_dict)

    g = sns.clustermap(predictions_df, **heatmap_kwargs)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # zero-size bars exist only to feed class colors into the legend
    for class_name in list(label_color_dict.keys()):
        g.ax_col_dendrogram.bar(
            0, 0, color=label_color_dict[class_name], label=class_name, linewidth=0
        )

    g.ax_col_dendrogram.legend(
        title=f"{title}",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, 1),
        facecolor="white",
    )

    output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
    plt.savefig(output_file, bbox_inches="tight")
def make_colorbar(embs_df, label):
    """
    Build the row-color annotation for a heatmap.

    Returns a (DataFrame of per-row colors named after `label`,
    class -> color dict) pair derived from the `label` column of embs_df.
    """
    class_labels = list(embs_df[label])

    row_color_series = gen_heatmap_class_colors(class_labels, embs_df)
    label_colors = pd.DataFrame(row_color_series, columns=[label])

    # create dictionary for colors and classes
    label_color_dict = gen_heatmap_class_dict(class_labels, label_colors[label])
    return label_colors, label_color_dict
def gen_heatmap_class_colors(labels, df):
    """
    Assign one cubehelix color to each distinct label and return a Series of
    per-row colors aligned with df's index.

    Note: the palette lookup table is keyed on str(label); non-string labels
    would fail to map (preserved from the original implementation).
    """
    distinct_labels = Counter(labels).keys()
    palette = sns.cubehelix_palette(
        len(distinct_labels),
        light=0.9,
        dark=0.1,
        hue=1,
        reverse=True,
        start=1,
        rot=-2,
    )
    lut = dict(zip(map(str, distinct_labels), palette))
    return pd.Series(labels, index=df.index).map(lut)
def gen_heatmap_class_dict(classes, label_colors_series):
    """
    Collapse per-row (class, color) pairs into a class -> color dict,
    keeping the first color observed for each class.
    """
    color_by_class = {}
    for class_name, color in zip(classes, label_colors_series):
        # keep-first semantics, matching the original drop_duplicates approach
        if class_name not in color_by_class:
            color_by_class[class_name] = color
    return color_by_class