diff --git a/Geneformer-V2-104M/model.safetensors b/Geneformer-V2-104M/model.safetensors deleted file mode 100755 index 0cc349a7acab42b89dcd92a6a7b4e6bfaf53b5f4..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M/model.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fff5cba29ddd8792991fa77b4872246fbe548a178cebda3775cdc72b67780e7f -size 417571156 diff --git a/Geneformer-V2-104M/training_args.bin b/Geneformer-V2-104M/training_args.bin deleted file mode 100755 index 443011e54067545b01bda3cf3979ed4775713c0b..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0d8ddd9e4f35b5fe23a3adaae03aa4480705ca82eed546a488f970adb3752d9d -size 5496 diff --git a/Geneformer-V2-104M_CLcancer/model.safetensors b/Geneformer-V2-104M_CLcancer/model.safetensors deleted file mode 100755 index 659afe038802fc98201c0216f28bb466ed84f89d..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M_CLcancer/model.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:827738139bfed4bafa9d1f3df7c6146da2e3b85f7225076adc32c6eda0ba4357 -size 417571156 diff --git a/Geneformer-V2-104M_CLcancer/training_args.bin b/Geneformer-V2-104M_CLcancer/training_args.bin deleted file mode 100755 index 7199a4754edc3cbf4f76ce9659e4a8a176082dc5..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M_CLcancer/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8cf8ce52b498253adc6df53197a99821fa145c19b8ae5eeb8d15be76b8b7ddb3 -size 4984 diff --git a/Geneformer-V2-316M/model.safetensors b/Geneformer-V2-316M/model.safetensors deleted file mode 100755 index f9c0c25c4a1df80ddc1455def715a9856152882c..0000000000000000000000000000000000000000 --- a/Geneformer-V2-316M/model.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3 -size 1265455076 diff --git a/Geneformer-V2-316M/training_args.bin b/Geneformer-V2-316M/training_args.bin deleted file mode 100755 index 630bab56b199325b337a9969d30167f5b73b7815..0000000000000000000000000000000000000000 --- a/Geneformer-V2-316M/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e45150f9a4ca34cb4e91ce79f65f3d99d9d66df9f66a37517a352d291008e0b8 -size 5432 diff --git a/MANIFEST.in b/MANIFEST.in index 764efdc5ada844c921601ae1ba29bcaf66001986..c3875d90a1e1ee1715279ba71ae3efc1a46643e8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,4 @@ -include geneformer/gene_median_dictionary_gc104M.pkl -include geneformer/gene_name_id_dict_gc104M.pkl -include geneformer/ensembl_mapping_dict_gc104M.pkl -include geneformer/token_dictionary_gc104M.pkl - -include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl -include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl -include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl -include geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl \ No newline at end of file +include geneformer/gene_median_dictionary_gc95M.pkl +include geneformer/gene_name_id_dict_gc95M.pkl +include geneformer/ensembl_mapping_dict_gc95M.pkl +include geneformer/token_dictionary_gc95M.pkl diff --git a/README.md b/README.md index 5610efbd5b32f59af3a1e70b7ace9a94f21ac21a..e012e6c44d67ab502a3616ddb9a7a6e4c48faf14 100644 
--- a/README.md +++ b/README.md @@ -1,36 +1,40 @@ --- datasets: ctheodoris/Genecorpus-30M license: apache-2.0 -tags: -- single-cell -- genomics --- # Geneformer Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology. - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies. -- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model, now trained on ~104 million transcriptomes, and our continual learning, multitask learning, and quantization strategies. +- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies. - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation. # Model Description -Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current updated Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer). The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model. +Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model. -Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (~30M for V1, ~104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. 
Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable. +Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable. The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels. We detail applications and results in [our manuscript](https://rdcu.be/ddrx0). -During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets. 
+During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets. The repository includes the following pretrained models: -- Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes -- Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes +L=layers\ +M=millions of cells used for pretraining\ +i=input size\ +(pretraining date) -The current default model in the main directory of the repository is Geneformer-V2-316M. +- GF-6L-30M-i2048 (June 2021) +- GF-12L-30M-i2048 (June 2021) +- GF-12L-95M-i4096 (April 2024) +- GF-20L-95M-i4096 (April 2024) -The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer. +The current default model in the main directory of the repository is GF-12L-95M-i4096. + +The repository also contains fine-tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer. # Application The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification. @@ -78,13 +82,9 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main - extracting and plotting cell embeddings - in silico perturbation -Please also see [here](https://tinyurl.com/geneformertutorial) for a quickstart tutorial for predicting candidate therapeutic targets with Geneformer. - -Complete documentation is available at https://geneformer.readthedocs.io/en/latest/. - Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. 
Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications. -Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer. +Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). # Citations - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors) diff --git a/config.json b/config.json index 6bc648aa565eabb748a6a43ee4def5032a0d5237..86e20c35e6f257f0daeb00ebb92a0751d12d8fff 100644 --- a/config.json +++ b/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 1152, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 4608, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 18, - "num_hidden_layers": 18, + "num_attention_heads": 8, + "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.44.2", + "transformers_version": "4.37.1", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/docs/source/geneformer.in_silico_perturber.rst b/docs/source/geneformer.in_silico_perturber.rst index 419b1c74c126d19e2c0a915606c90932330da0c0..fab76dea3c46244ab15d3d77552bc538535675e5 100644 --- a/docs/source/geneformer.in_silico_perturber.rst +++ b/docs/source/geneformer.in_silico_perturber.rst @@ -5,4 +5,4 @@ geneformer.in\_silico\_perturber :members: :undoc-members: :show-inheritance: - :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, , isp_perturb_all_special, isp_perturb_set_special, update_perturbation_dictionary + :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary diff --git a/examples/cell_classification.ipynb b/examples/cell_classification.ipynb index 64dd4323cb562a31b1641a000dab8aa5f59ac951..779eb5f4f0560db7ebc257b170b3bfea1ee7ecad 100644 --- a/examples/cell_classification.ipynb +++ b/examples/cell_classification.ipynb @@ -13,7 +13,7 @@ "id": "1792e51c-86c3-406f-be5a-273c4e4aec20", "metadata": {}, "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters 
for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters." + "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization." ] }, { @@ -68,8 +68,6 @@ " \"per_device_train_batch_size\": 12,\n", " \"seed\": 73,\n", "}\n", - "\n", - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"cell\",\n", " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n", " filter_data=filter_data_dict,\n", @@ -78,7 +76,6 @@ " freeze_layers = 2,\n", " num_crossval_splits = 1,\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -128,7 +125,7 @@ " \"train\": train_ids+eval_ids,\n", " \"test\": test_ids}\n", "\n", - "# Example input_data_file for 30M model: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", + "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", "cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n", " output_directory=output_dir,\n", " output_prefix=output_prefix,\n", @@ -263,8 +260,8 @@ " \"train\": train_ids,\n", " \"eval\": eval_ids}\n", "\n", - "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", - "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", + "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -449,7 +446,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/distributed_multitask_cell_classification.ipynb b/examples/distributed_multitask_cell_classification.ipynb deleted file mode 100644 index dfc601ddd6745982175cbd4c3cab6d2db581f1cf..0000000000000000000000000000000000000000 --- a/examples/distributed_multitask_cell_classification.ipynb +++ /dev/null @@ -1,149 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "b3266a7b", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "from geneformer import MTLClassifier" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "3e12ac9f", - "metadata": {}, - "outputs": [], - "source": [ - "# Define paths\n", - "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n", - "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n", - "train_path = \"/path/to/train/data.dataset\"\n", - "val_path = \"/path/to/val/data.dataset\"\n", - "test_path = \"/path/to/test/data.dataset\"\n", - "results_dir = \"/path/to/results/directory\"\n", - "model_save_path = \"/path/to/model/save/path\"\n", - "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n", - "\n", - "# Define tasks and hyperparameters\n", - "# task_columns should be a list of column names from your dataset\n", - "# Each column represents a specific classification task (e.g. cell type, disease state)\n", - "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9bd7562", - "metadata": {}, - "outputs": [], - "source": [ - "# Check GPU environment\n", - "num_gpus = torch.cuda.device_count()\n", - "use_distributed = num_gpus > 1\n", - "print(f\"Number of GPUs detected: {num_gpus}\")\n", - "print(f\"Using distributed training: {use_distributed}\")\n", - "\n", - "# Set environment variables for distributed training when multiple GPUs are available\n", - "if use_distributed:\n", - " os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n", - " os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n", - " print(\"Distributed environment variables set.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b6ff3618", - "metadata": {}, - "outputs": [], - "source": [ - "#Define Hyperparameters for Optimization\n", - "hyperparameters = {\n", - " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n", - " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n", - " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n", - " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n", - " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n", - " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f665c5a7", - "metadata": {}, - "outputs": [], - "source": [ - "mc = MTLClassifier(\n", - " task_columns=task_columns, # Our defined classification tasks\n", - " study_name=\"MTLClassifier_distributed\",\n", - " pretrained_path=pretrained_path,\n", - " train_path=train_path,\n", - " val_path=val_path,\n", - " test_path=test_path,\n", - " model_save_path=model_save_path,\n", - " results_dir=results_dir,\n", - " tensorboard_log_dir=tensorboard_log_dir,\n", - " hyperparameters=hyperparameters,\n", - " # Distributed training parameters\n", - " distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n", - " master_addr=\"localhost\" if use_distributed else None,\n", - " master_port=\"12355\" if use_distributed else None,\n", - " # Other training parameters\n", - " n_trials=15, # Number of trials for hyperparameter optimization\n", - " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n", - " batch_size=8, # Adjust based on available GPU memory\n", - " gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n", - " gradient_clipping=True, # Enable gradient clipping for 
stability\n", - " max_grad_norm=1.0, # Set maximum gradient norm\n", - " seed=42\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f69f7b6a", - "metadata": {}, - "outputs": [], - "source": [ - "# Run Hyperparameter Optimization with Distributed Training\n", - "if __name__ == \"__main__\":\n", - " # This guard is required for distributed training to prevent\n", - " # infinite subprocess spawning when using torch.multiprocessing\n", - " mc.run_optuna_study()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3affd5dd", - "metadata": {}, - "outputs": [], - "source": [ - "# Evaluate the Model on Test Data\n", - "if __name__ == \"__main__\":\n", - " mc.load_and_evaluate_test_model()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bio", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/extract_and_plot_cell_embeddings.ipynb b/examples/extract_and_plot_cell_embeddings.ipynb index 8571064ab3a3a3f35d3fcf2d90a64fc0ebcb071f..fe9474b4e1eb132250574717d780f74f8f4bd188 100644 --- a/examples/extract_and_plot_cell_embeddings.ipynb +++ b/examples/extract_and_plot_cell_embeddings.ipynb @@ -18,7 +18,6 @@ "outputs": [], "source": [ "# initiate EmbExtractor\n", - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "embex = EmbExtractor(model_type=\"CellClassifier\",\n", " num_classes=3,\n", " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n", @@ -27,13 +26,12 @@ " emb_label=[\"disease\",\"cell_type\"],\n", " labels_to_plot=[\"disease\"],\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)\n", "\n", "# extracts embedding from input data\n", "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n", - "# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", - "embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n", + "# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", + "embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n", " \"path/to/input_data/\",\n", " \"path/to/output_directory/\",\n", " \"output_prefix\")\n" @@ -131,7 +129,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/gene_classification.ipynb b/examples/gene_classification.ipynb index b739754a95c23c8f74f9cf4a85e05da9c2af58a8..544e7130ae9e4918eb0cc9f280a1fe9c34c666f5 100644 --- a/examples/gene_classification.ipynb +++ b/examples/gene_classification.ipynb @@ -13,7 +13,7 @@ "id": "79539e95-2c9c-4162-835c-f0d158abb15d", "metadata": {}, "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. 
Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters." + "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." ] }, { @@ -71,14 +71,12 @@ } ], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"gene\",\n", " gene_class_dict = gene_class_dict,\n", " max_ncells = 10_000,\n", " freeze_layers = 4,\n", " num_crossval_splits = 5,\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -104,7 +102,7 @@ } ], "source": [ - "# Example input_data_file for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n", + "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n", "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n", " output_directory=output_dir,\n", " output_prefix=output_prefix)" @@ -842,8 +840,8 @@ } ], "source": [ - "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", - "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", + "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -1065,14 +1063,12 @@ } ], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"gene\",\n", " gene_class_dict = gene_class_dict,\n", " max_ncells = 10_000,\n", " freeze_layers = 4,\n", " num_crossval_splits = 0,\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -1219,8 +1215,8 @@ } ], "source": [ - "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", - "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", + "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n", + "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -1244,7 
+1240,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/in_silico_perturbation.ipynb b/examples/in_silico_perturbation.ipynb index 00607d24415b5a6034fca467ec90c2e76ae72d43..6d93efff57ba6a0fe85b85bf8ac188b80e515f3b 100644 --- a/examples/in_silico_perturbation.ipynb +++ b/examples/in_silico_perturbation.ipynb @@ -39,19 +39,17 @@ "\n", "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n", "\n", - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", - "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n", + "embex = EmbExtractor(model_type=\"CellClassifier\",\n", " num_classes=3,\n", " filter_data=filter_data_dict,\n", " max_ncells=1000,\n", " emb_layer=0,\n", " summary_stat=\"exact_mean\",\n", " forward_batch_size=256,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)\n", "\n", "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n", - " \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n", + " \"path/to/model\",\n", " \"path/to/input_data\",\n", " \"path/to/output_directory\",\n", " \"output_prefix\")" @@ -66,15 +64,14 @@ }, "outputs": [], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "isp = InSilicoPerturber(perturb_type=\"delete\",\n", " perturb_rank_shift=None,\n", " genes_to_perturb=\"all\",\n", " combos=0,\n", " anchor_gene=None,\n", - " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n", + " model_type=\"CellClassifier\",\n", " num_classes=3,\n", - " emb_mode=\"cell\", # OF NOTE: SET TO \"CELL\" FOR V1 MODEL. 
FOR V2, SHOULD BE \"CLS\" (current default).\n", + " emb_mode=\"cell\",\n", " cell_emb_style=\"mean_pool\",\n", " filter_data=filter_data_dict,\n", " cell_states_to_model=cell_states_to_model,\n", @@ -82,7 +79,6 @@ " max_ncells=2000,\n", " emb_layer=0,\n", " forward_batch_size=400,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -94,10 +90,9 @@ "outputs": [], "source": [ "# outputs intermediate files from in silico perturbation\n", - "\n", - "isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n", + "isp.perturb_data(\"path/to/model\",\n", " \"path/to/input_data\",\n", - " \"path/to/isp_output_directory\",\n", + " \"path/to/output_directory\",\n", " \"output_prefix\")" ] }, @@ -108,13 +103,11 @@ "metadata": {}, "outputs": [], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n", " genes_perturbed=\"all\",\n", " combos=0,\n", " anchor_gene=None,\n", - " cell_states_to_model=cell_states_to_model,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE)" + " cell_states_to_model=cell_states_to_model)" ] }, { @@ -125,9 +118,9 @@ "outputs": [], "source": [ "# extracts data from intermediate files and processes stats to output in final .csv\n", - "ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n", + "ispstats.get_stats(\"path/to/input_data\",\n", " None,\n", - " \"path/to/isp_stats_output_directory\",\n", + " \"path/to/output_directory\",\n", " \"output_prefix\")" ] } @@ -148,7 +141,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/examples/multitask_cell_classification.ipynb b/examples/multitask_cell_classification.ipynb index 998e678d7eb812dbf6e5e3764b982e9d6620fd63..b3f13b7477c7fb8797bf871b90f943877fb61029 100644 --- a/examples/multitask_cell_classification.ipynb +++ b/examples/multitask_cell_classification.ipynb @@ -286,7 +286,7 @@ " filter_data_dict=filter_data_dict,\n", " max_ncells=1000, # Number of cells to extract embeddings for\n", " emb_layer=0, # Use the second to last layer\n", - " emb_mode = \"cls\", # Use CLS token embedding for V2 model\n", + " emb_mode = \"cls\",\n", " summary_stat=\"exact_mean\",\n", " forward_batch_size=8, # Adjust based on available GPU memory\n", " nproc=4\n", @@ -324,7 +324,7 @@ " perturb_type=perturb_type,\n", " genes_to_perturb=\"all\", # Perturb all genes\n", " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n", - " emb_mode=\"cls\", # Use CLS token embedding for V2 model\n", + " emb_mode=\"cls\", # Use CLS token embedding\n", " cell_states_to_model=cell_states_to_model,\n", " state_embs_dict=state_embs_dict,\n", " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n", @@ -412,7 +412,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/tokenizing_scRNAseq_data.ipynb b/examples/tokenizing_scRNAseq_data.ipynb index 7f7331fb63d46567465ddcc6cea5560e09a40e24..6b2ea01eec34ebfc784b90915ca0c6a4b684e2d7 100644 --- a/examples/tokenizing_scRNAseq_data.ipynb +++ b/examples/tokenizing_scRNAseq_data.ipynb @@ 
-12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "1fe86f48-5578-47df-b373-58c21ec170ab", + "id": "350e6252-b783-494b-9767-f087eb868a15", "metadata": {}, "source": [ "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n", @@ -25,17 +25,11 @@ "\n", "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n", "\n", - "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer." - ] - }, - { - "cell_type": "markdown", - "id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b", - "metadata": {}, - "source": [ - "**********************************************************************************************************\n", - "#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size is used for the correct model version.\n", - "#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"." + "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.\n", + "\n", + "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE ARE USED FOR THE CORRECT MODEL.\n", + "\n", + "#### The 95M model series also requires the special_token argument to be set to True and model_input_size to be 4096." 
] }, { @@ -55,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n", + "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n", "tk.tokenize_data(\"loom_data_directory\", \n", " \"output_directory\", \n", " \"output_prefix\", \n", @@ -79,7 +73,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/Geneformer-V2-104M/config.json b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json similarity index 68% rename from Geneformer-V2-104M/config.json rename to fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json index f1f165b1d6df5149022c0c40a23a31fd258d98cf..bc8099f84af0bd3e35d700a7135dd417e38f6bea 100755 --- a/Geneformer-V2-104M/config.json +++ b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 3072, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 12, + "num_attention_heads": 8, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.44.2", + "transformers_version": "4.37.2", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin new file mode 100755 index 0000000000000000000000000000000000000000..87625b1b8fe02c6aa0fc3ffd8c746275570e589d --- /dev/null +++ b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4 +size 152363342 diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/config.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/config.json rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/optimizer.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/optimizer.pt rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/pytorch_model.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/pytorch_model.bin rename to 
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/rng_state.pth b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/rng_state.pth rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/scheduler.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/scheduler.pt rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/trainer_state.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/trainer_state.json rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/training_args.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/training_args.bin rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin diff --git a/geneformer/__init__.py b/geneformer/__init__.py index 38cfecec8ec7924d7c74bf195165417b4c891829..52d43619d06f2a7c019b480d1958a82d287d26ff 100644 --- a/geneformer/__init__.py +++ b/geneformer/__init__.py @@ -4,15 +4,10 @@ from pathlib import Path warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip -GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl" -TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl" -ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl" -ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl" - -GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl" -TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl" -ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl" -ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl" +GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl" +TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl" +ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl" +ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl" from . 
import ( collator_for_classification, diff --git a/geneformer/classifier.py b/geneformer/classifier.py index ee99a8a9c0ac4d53e56b18b01a11aa3b1eb5d5c5..b5ac161e461a014cce6df0d75262a1bc98e88259 100644 --- a/geneformer/classifier.py +++ b/geneformer/classifier.py @@ -48,13 +48,11 @@ import logging import os import pickle import subprocess -from packaging.version import parse from pathlib import Path import numpy as np import pandas as pd import seaborn as sns -import transformers from tqdm.auto import tqdm, trange from transformers import Trainer from transformers.training_args import TrainingArguments @@ -73,7 +71,6 @@ sns.set() logger = logging.getLogger(__name__) -transformers_version = parse(transformers.__version__) class Classifier: valid_option_dict = { @@ -92,7 +89,6 @@ class Classifier: "no_eval": {bool}, "stratify_splits_col": {None, str}, "forward_batch_size": {int}, - "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "nproc": {int}, "ngpu": {int}, @@ -116,7 +112,6 @@ class Classifier: stratify_splits_col=None, no_eval=False, forward_batch_size=100, - model_version="V2", token_dictionary_file=None, nproc=4, ngpu=1, @@ -193,9 +188,6 @@ class Classifier: | Otherwise, will perform eval during training. forward_batch_size : int | Batch size for forward pass (for evaluation, not training). - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : None, str | Default is to use token dictionary file from Geneformer | Otherwise, will load custom gene token dictionary. @@ -230,16 +222,14 @@ class Classifier: self.stratify_splits_col = stratify_splits_col self.no_eval = no_eval self.forward_batch_size = forward_batch_size - self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.nproc = nproc self.ngpu = ngpu - + if self.training_args is None: logger.warning( "Hyperparameter tuning is highly recommended for optimal results. " - "No training_args provided; using default hyperparameters. " - "Please note: these defaults are not recommended to be used uniformly across tasks." + "No training_args provided; using default hyperparameters." ) self.validate_options() @@ -254,10 +244,7 @@ class Classifier: ] = self.cell_state_dict["states"] # load token dictionary (Ensembl IDs:token) - if self.model_version == "V1": - from . import TOKEN_DICTIONARY_FILE_30M - self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - elif self.token_dictionary_file is None: + if self.token_dictionary_file is None: self.token_dictionary_file = TOKEN_DICTIONARY_FILE with open(self.token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) @@ -367,7 +354,6 @@ class Classifier: attr_to_balance=None, max_trials=100, pval_threshold=0.1, - id_class_dict_path=None, ): """ Prepare data for cell state or gene classification. @@ -410,10 +396,6 @@ class Classifier: pval_threshold : None, float | P-value threshold to use for attribute balancing across splits | E.g. 
if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance - id_class_dict_path : Path - | Path to *_id_class_dict.pkl from prior run of prepare_data to reuse for labeling new data - | Dictionary with keys being numeric class labels and values being original dataset class labels - | Note: only available for CellClassifiers """ if test_size is None: @@ -455,25 +437,14 @@ class Classifier: ) # rename cell state column to "label" data = cu.rename_cols(data, self.cell_state_dict["state_key"]) - - # convert classes to numerical labels and save as id_class_dict - if id_class_dict_path is not None: - with open(id_class_dict_path,"rb") as fp: - id_class_dict = pickle.load(fp) - else: - id_class_dict = None - data, id_class_dict = cu.label_classes( - self.classifier, data, self.cell_state_dict, self.nproc, id_class_dict, - ) - elif self.classifier == "gene": # convert classes to numerical labels and save as id_class_dict # of note, will label all genes in gene_class_dict # if (cross-)validating, genes will be relabeled in column "labels" for each split # at the time of training with Classifier.validate - data, id_class_dict = cu.label_classes( - self.classifier, data, self.gene_class_dict, self.nproc - ) + data, id_class_dict = cu.label_classes( + self.classifier, data, self.gene_class_dict, self.nproc + ) # save id_class_dict for future reference id_class_output_path = ( @@ -801,7 +772,7 @@ class Classifier: # 5-fold cross-validate num_cells = len(data) fifth_cells = int(np.floor(num_cells * 0.2)) - num_eval = int(min((self.eval_size * num_cells), fifth_cells)) + num_eval = min((self.eval_size * num_cells), fifth_cells) start = i * fifth_cells end = start + num_eval eval_indices = [j for j in range(start, end)] @@ -1085,8 +1056,6 @@ class Classifier: if eval_data is None: def_training_args["evaluation_strategy"] = "no" def_training_args["load_best_model_at_end"] = False - if transformers_version >= parse("4.46"): - def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy") def_training_args.update( {"save_strategy": "epoch", "save_total_limit": 1} ) # only save last model for each run @@ -1258,8 +1227,6 @@ class Classifier: if eval_data is None: def_training_args["evaluation_strategy"] = "no" def_training_args["load_best_model_at_end"] = False - if transformers_version >= parse("4.46"): - def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy") training_args_init = TrainingArguments(**def_training_args) if self.freeze_layers is not None: @@ -1313,7 +1280,6 @@ class Classifier: predict=False, output_directory=None, output_prefix=None, - predict_metadata=None, ): """ Evaluate the fine-tuned model. 
@@ -1339,11 +1305,9 @@ class Classifier: ##### Evaluate the model ##### labels = id_class_dict.keys() - - y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict( - model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata + y_pred, y_true, logits_list = eu.classifier_predict( + model, self.classifier, eval_data, self.forward_batch_size ) - conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics( y_pred, y_true, logits_list, num_classes, labels ) @@ -1353,9 +1317,6 @@ class Classifier: "label_ids": y_true, "predictions": logits_list, } - if predict_metadata is not None: - pred_dict["prediction_metadata"] = predict_metadata_all - pred_dict_output_path = ( Path(output_directory) / f"{output_prefix}_pred_dict" ).with_suffix(".pkl") @@ -1376,7 +1337,6 @@ class Classifier: output_directory, output_prefix, predict=True, - predict_metadata=None, ): """ Evaluate the fine-tuned model. @@ -1396,8 +1356,6 @@ class Classifier: | Prefix for output files predict : bool | Whether or not to save eval predictions - predict_metadata : None | list - | Metadata labels to output with predictions (columns in test_data_file) """ # load numerical id to class dictionary (id:class) @@ -1410,15 +1368,6 @@ class Classifier: # load previously filtered and prepared data test_data = pu.load_and_filter(None, self.nproc, test_data_file) - if predict_metadata is not None: - absent_metadata = [] - for predict_metadata_x in predict_metadata: - if predict_metadata_x not in test_data.features.keys(): - absent_metadata += [predict_metadata_x] - if len(absent_metadata)>0: - logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}") - raise - # load previously fine-tuned model model = pu.load_model( self.model_type, @@ -1437,7 +1386,6 @@ class Classifier: predict=predict, output_directory=output_directory, output_prefix=output_prefix, - predict_metadata=predict_metadata, ) all_conf_mat_df = pd.DataFrame( diff --git a/geneformer/classifier_utils.py b/geneformer/classifier_utils.py index a30884c34629573b16c3951517d42f05e57ec108..d2da349a731bbeb4dc023b48a6bd283c7381e236 100644 --- a/geneformer/classifier_utils.py +++ b/geneformer/classifier_utils.py @@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc): return data -def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None): +def label_classes(classifier, data, gene_class_dict, nproc): if classifier == "cell": label_set = set(data["label"]) elif classifier == "gene": @@ -113,24 +113,15 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None): ) raise - if id_class_dict is None: - class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) - id_class_dict = {v: k for k, v in class_id_dict.items()} - else: - class_id_dict = {v: k for k, v in id_class_dict.items()} - - if classifier == "gene": - inverse_gene_class_dict = {} - for key, value_list in gene_class_dict.items(): - for value in value_list: - inverse_gene_class_dict[value] = key + class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) + id_class_dict = {v: k for k, v in class_id_dict.items()} def classes_to_ids(example): if classifier == "cell": example["label"] = class_id_dict[example["label"]] elif classifier == "gene": example["labels"] = label_gene_classes( - example, class_id_dict, inverse_gene_class_dict + example, class_id_dict, gene_class_dict ) return example @@ -138,9 +129,9 @@ def label_classes(classifier, data, 
gene_class_dict, nproc, id_class_dict=None): return data, id_class_dict -def label_gene_classes(example, class_id_dict, inverse_gene_class_dict): +def label_gene_classes(example, class_id_dict, gene_class_dict): return [ - class_id_dict.get(inverse_gene_class_dict.get(token_id, -100), -100) + class_id_dict.get(gene_class_dict.get(token_id, -100), -100) for token_id in example["input_ids"] ] @@ -570,27 +561,6 @@ def compute_metrics(pred): return {"accuracy": acc, "macro_f1": macro_f1} -def robust_compute_objective(metrics: dict): - # tries both prefixed ("eval_") and raw metric names to support different transformers versions - metric_name = "macro_f1" - - # check for the prefixed version - prefixed_metric_name = f"eval_{metric_name}" - if prefixed_metric_name in metrics: - return metrics[prefixed_metric_name] - - # fall back to the raw name - elif metric_name in metrics: - return metrics[metric_name] - - # if neither is found, raise a clear error to help with debugging - raise KeyError( - f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. " - f"Please check your `compute_metrics` function and `TrainingArguments`. " - f"Available metrics: {list(metrics.keys())}" - ) - - def get_default_train_args(model, classifier, data, output_dir): num_layers = pu.quant_layers(model) freeze_layers = 0 diff --git a/geneformer/collator_for_classification.py b/geneformer/collator_for_classification.py index a34dda86cb56021a982d76111971f315b8e33e05..297fa666dbf0daeaa94e2ca203ace5f98570a30e 100644 --- a/geneformer/collator_for_classification.py +++ b/geneformer/collator_for_classification.py @@ -26,10 +26,9 @@ LARGE_INTEGER = int( 1e20 ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER -warnings.filterwarnings("ignore", message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach()", category=UserWarning, module="torch") - # precollator functions + class ExplicitEnum(Enum): """ Enum with more explicit error message for missing values. 
@@ -104,9 +103,6 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin): def pad_token_id(self): return self._pad_token_id - def save_pretrained(self, save_directory): - pass - def _get_padding_truncation_strategies( self, padding=True, @@ -645,8 +641,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification): def __call__(self, features): batch = self._prepare_batch(features) - # batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} - batch = {k: torch.tensor(v.clone().detach(), dtype=torch.int64) if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} + batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} return batch diff --git a/geneformer/emb_extractor.py b/geneformer/emb_extractor.py index d46ad463a38896d7628d7b49249f000d9783a8e2..20c22b91004a9668141c1d0329f4526fd052b9e9 100644 --- a/geneformer/emb_extractor.py +++ b/geneformer/emb_extractor.py @@ -42,8 +42,6 @@ def get_embs( special_token=False, summary_stat=None, silent=False, - save_tdigest=False, - tdigest_path=None, ): model_input_size = pu.get_model_input_size(model) total_batch_length = len(filtered_input_data) @@ -182,18 +180,12 @@ def get_embs( # calculate summary stat embs from approximated tdigests elif summary_stat is not None: if emb_mode == "cell": - if save_tdigest: - with open(f"{tdigest_path}","wb") as fp: - pickle.dump(embs_tdigests, fp) if summary_stat == "mean": summary_emb_list = tdigest_mean(embs_tdigests, emb_dims) elif summary_stat == "median": summary_emb_list = tdigest_median(embs_tdigests, emb_dims) embs_stack = torch.tensor(summary_emb_list) elif emb_mode == "gene": - if save_tdigest: - with open(f"{tdigest_path}","wb") as fp: - pickle.dump(embs_tdigests_dict, fp) if summary_stat == "mean": [ update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims) @@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels): return embs_df -def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"): +def label_gene_embs(embs, downsampled_data, token_gene_dict): gene_set = { element for sublist in downsampled_data["input_ids"] for element in sublist } @@ -275,52 +267,25 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea ) for k in dict_i.keys(): gene_emb_dict[k].append(dict_i[k]) - if gene_emb_style != "all": - for k in gene_emb_dict.keys(): - gene_emb_dict[k] = ( - torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0) - .cpu() - .numpy() - ) - embs_df = pd.DataFrame(gene_emb_dict).T - else: - embs_df = dict_lol_to_df(gene_emb_dict) + for k in gene_emb_dict.keys(): + gene_emb_dict[k] = ( + torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0) + .cpu() + .numpy() + ) + embs_df = pd.DataFrame(gene_emb_dict).T embs_df.index = [token_gene_dict[token] for token in embs_df.index] return embs_df -def dict_lol_to_df(data_dict): - # save dictionary with values being list of equal-length lists as dataframe - df_data = [] - for key, list_of_lists in data_dict.items(): - for i, sublist in enumerate(list_of_lists): - row_data = [key, i] + sublist.tolist() - df_data.append(row_data) - - # determine column names based on the length of sublists - # assuming all sublists have the same length - num_columns_from_sublist = len(list(data_dict.values())[0][0]) - column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)] - - # create the dataframe - df = pd.DataFrame(df_data, 
columns=column_names) - - # set 'Gene' as the index - df = df.set_index('Gene') - - return df - -def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0): + +def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0): only_embs_df = embs_df.iloc[:, :emb_dims] only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str) only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype( str ) vars_dict = {"embs": only_embs_df.columns} - - obs_dict = {"cell_id": list(only_embs_df.index)} - for label_i in labels_clean: - obs_dict[label_i] = list(embs_df[label_i]) - + obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])} adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict) sc.tl.pca(adata, svd_solver="arpack") sc.pp.neighbors(adata, random_state=seed) @@ -331,26 +296,21 @@ def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, if kwargs_dict is not None: default_kwargs_dict.update(kwargs_dict) - for label_i in labels_clean: - output_prefix_label = output_prefix + f"_umap_{label_i}" - output_file = ( - Path(output_directory) / output_prefix_label - ).with_suffix(".pdf") - - cats = set(embs_df[label_i]) - - with plt.rc_context(): - ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict) - ax.legend( - markerscale=2, - frameon=False, - loc="center left", - bbox_to_anchor=(1, 0.5), - ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3), - ) - plt.show() - plt.savefig(output_file, bbox_inches="tight") - + cats = set(embs_df[label]) + + with plt.rc_context(): + ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict) + ax.legend( + markerscale=2, + frameon=False, + loc="center left", + bbox_to_anchor=(1, 0.5), + ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3), + ) + plt.show() + plt.savefig(output_file, bbox_inches="tight") + + def gen_heatmap_class_colors(labels, df): pal = sns.cubehelix_palette( len(Counter(labels).keys()), @@ -435,14 +395,13 @@ class EmbExtractor: "num_classes": {int}, "emb_mode": {"cls", "cell", "gene"}, "cell_emb_style": {"mean_pool"}, - "gene_emb_style": {"mean_pool", "all"}, + "gene_emb_style": {"mean_pool"}, "filter_data": {None, dict}, "max_ncells": {None, int}, "emb_layer": {-1, 0}, "emb_label": {None, list}, "labels_to_plot": {None, list}, "forward_batch_size": {int}, - "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "nproc": {int}, "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"}, @@ -452,7 +411,7 @@ class EmbExtractor: self, model_type="Pretrained", num_classes=0, - emb_mode="cls", + emb_mode="cell", cell_emb_style="mean_pool", gene_emb_style="mean_pool", filter_data=None, @@ -463,8 +422,6 @@ class EmbExtractor: forward_batch_size=100, nproc=4, summary_stat=None, - save_tdigest=False, - model_version="V2", token_dictionary_file=None, ): """ @@ -472,8 +429,8 @@ class EmbExtractor: **Parameters:** - model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "Pretrained-Quantized"} - | Whether model is the pretrained Geneformer (full or quantized) or a fine-tuned gene or cell classifier. + model_type : {"Pretrained", "GeneClassifier", "CellClassifier"} + | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier. num_classes : int | If model is a gene or cell classifier, specify number of classes it was trained to classify. 
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier. @@ -483,9 +440,9 @@ class EmbExtractor: cell_emb_style : {"mean_pool"} | Method for summarizing cell embeddings if not using CLS token. | Currently only option is mean pooling of gene embeddings for given cell. - gene_emb_style : {"mean_pool", "all} + gene_emb_style : "mean_pool" | Method for summarizing gene embeddings. - | Currently only option is returning all or mean pooling of contextual gene embeddings for given gene. + | Currently only option is mean pooling of contextual gene embeddings for given gene. filter_data : None, dict | Default is to extract embeddings from all input data. | Otherwise, dictionary specifying .dataset column name and list of values to filter by. @@ -515,12 +472,6 @@ class EmbExtractor: | If mean or median, outputs only approximated mean or median embedding of input data. | Non-exact recommended if encountering memory constraints while generating goal embedding positions. | Non-exact is slower but more memory-efficient. - save_tdigest : bool - | Whether to save a dictionary of tdigests for each gene and embedding dimension - | Only applies when summary_stat is not None - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Default is the Geneformer token dictionary | Path to pickle file containing token dictionary (Ensembl ID:token). @@ -551,7 +502,6 @@ class EmbExtractor: self.emb_layer = emb_layer self.emb_label = emb_label self.labels_to_plot = labels_to_plot - self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.forward_batch_size = forward_batch_size self.nproc = nproc @@ -561,29 +511,13 @@ class EmbExtractor: else: self.summary_stat = summary_stat self.exact_summary_stat = None - self.save_tdigest = save_tdigest self.validate_options() - if (summary_stat is None) and (save_tdigest is True): - logger.warning( - "tdigests will not be saved since summary_stat is None." - ) - save_tdigest = False - - if self.model_version == "V1": - from . import TOKEN_DICTIONARY_FILE_30M - self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - if self.emb_mode == "cls": - self.emb_mode = "cell" - logger.warning( - "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a token." 
- ) - # load token dictionary (Ensembl IDs:token) if self.token_dictionary_file is None: - self.token_dictionary_file = TOKEN_DICTIONARY_FILE - with open(self.token_dictionary_file, "rb") as f: + token_dictionary_file = TOKEN_DICTIONARY_FILE + with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()} @@ -662,12 +596,6 @@ class EmbExtractor: filtered_input_data = pu.load_and_filter( self.filter_data, self.nproc, input_data_file ) - - # Check to make sure that all the labels exist in the tokenized data: - if self.emb_label is not None: - for label in self.emb_label: - assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features" - if cell_state is not None: filtered_input_data = pu.filter_by_dict( filtered_input_data, cell_state, self.nproc @@ -677,10 +605,6 @@ class EmbExtractor: self.model_type, self.num_classes, model_directory, mode="eval" ) layer_to_quant = pu.quant_layers(model) + self.emb_layer - if self.save_tdigest: - tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl") - else: - tdigest_path = None embs = get_embs( model=model, filtered_input_data=downsampled_data, @@ -690,8 +614,6 @@ class EmbExtractor: forward_batch_size=self.forward_batch_size, token_gene_dict=self.token_gene_dict, summary_stat=self.summary_stat, - save_tdigest=self.save_tdigest, - tdigest_path=tdigest_path, ) if self.emb_mode == "cell": @@ -701,7 +623,7 @@ class EmbExtractor: embs_df = pd.DataFrame(embs.cpu().numpy()).T elif self.emb_mode == "gene": if self.summary_stat is None: - embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style) + embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict) elif self.summary_stat is not None: embs_df = pd.DataFrame(embs).T embs_df.index = [self.token_gene_dict[token] for token in embs_df.index] @@ -717,14 +639,14 @@ class EmbExtractor: embs = embs.mean(dim=0) emb_dims = pu.get_model_emb_dims(model) embs_df = pd.DataFrame( - embs_df.iloc[:, 0 : emb_dims].mean(axis="rows"), + embs_df[0 : emb_dims - 1].mean(axis="rows"), columns=[self.exact_summary_stat], ).T elif self.exact_summary_stat == "exact_median": embs = torch.median(embs, dim=0)[0] emb_dims = pu.get_model_emb_dims(model) embs_df = pd.DataFrame( - embs_df.iloc[:, 0 : emb_dims].median(axis="rows"), + embs_df[0 : emb_dims - 1].median(axis="rows"), columns=[self.exact_summary_stat], ).T @@ -797,12 +719,6 @@ class EmbExtractor: ) raise - if self.emb_label is not None: - logger.error( - "For extracting state embs, emb_label should be None since labels are based on state embs dict keys." - ) - raise - state_embs_dict = dict() state_key = cell_states_to_model["state_key"] for k, v in cell_states_to_model.items(): @@ -885,14 +801,14 @@ class EmbExtractor: raise if max_ncells_to_plot is not None: - if self.max_ncells is not None: - if max_ncells_to_plot > self.max_ncells: - max_ncells_to_plot = self.max_ncells - logger.warning( - "max_ncells_to_plot must be <= max_ncells. " - f"Changing max_ncells_to_plot to {self.max_ncells}." - ) - embs = embs.sample(max_ncells_to_plot, axis=0) + if max_ncells_to_plot > self.max_ncells: + max_ncells_to_plot = self.max_ncells + logger.warning( + "max_ncells_to_plot must be <= max_ncells. " + f"Changing max_ncells_to_plot to {self.max_ncells}." 
+ ) + elif max_ncells_to_plot < self.max_ncells: + embs = embs.sample(max_ncells_to_plot, axis=0) if self.emb_label is None: label_len = 0 @@ -913,9 +829,12 @@ class EmbExtractor: f"Label {label} from labels_to_plot " f"not present in provided embeddings dataframe." ) - - labels_clean = [label for label in self.labels_to_plot if label in emb_labels] - plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict) + continue + output_prefix_label = output_prefix + f"_umap_{label}" + output_file = ( + Path(output_directory) / output_prefix_label + ).with_suffix(".pdf") + plot_umap(embs, emb_dims, label, output_file, kwargs_dict) if plot_style == "heatmap": for label in self.labels_to_plot: diff --git a/geneformer/ensembl_mapping_dict_gc104M.pkl b/geneformer/ensembl_mapping_dict_gc95M.pkl similarity index 100% rename from geneformer/ensembl_mapping_dict_gc104M.pkl rename to geneformer/ensembl_mapping_dict_gc95M.pkl diff --git a/geneformer/evaluation_utils.py b/geneformer/evaluation_utils.py index 88f850cda75150d997f4fb6b4c13f2b13b2e2072..b42833785819a08d9afc1cdb84a210c46a9e94ea 100644 --- a/geneformer/evaluation_utils.py +++ b/geneformer/evaluation_utils.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd import seaborn as sns import torch -import datasets from datasets.utils.logging import disable_progress_bar, enable_progress_bar from sklearn import preprocessing from sklearn.metrics import ( @@ -21,15 +20,20 @@ from sklearn.metrics import ( ) from tqdm.auto import trange +from . import TOKEN_DICTIONARY_FILE from .emb_extractor import make_colorbar logger = logging.getLogger(__name__) -def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict): +def preprocess_classifier_batch(cell_batch, max_len, label_name): if max_len is None: max_len = max([len(i) for i in cell_batch["input_ids"]]) + # load token dictionary (Ensembl IDs:token) + with open(TOKEN_DICTIONARY_FILE, "rb") as f: + gene_token_dict = pickle.load(f) + def pad_label_example(example): example[label_name] = np.pad( example[label_name], @@ -77,7 +81,7 @@ def py_softmax(vector): return e / e.sum() -def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None): +def classifier_predict(model, classifier_type, evalset, forward_batch_size): if classifier_type == "gene": label_name = "labels" elif classifier_type == "cell": @@ -85,14 +89,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene predict_logits = [] predict_labels = [] - - predict_metadata_all = None - - if predict_metadata is not None: - predict_metadata_all = dict() - for metadata_name in predict_metadata: - predict_metadata_all[metadata_name] = [] - model.eval() # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims @@ -107,21 +103,11 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene for i in trange(0, evalset_len, forward_batch_size): max_range = min(i + forward_batch_size, evalset_len) batch_evalset = evalset.select([i for i in range(i, max_range)]) - - if predict_metadata is not None: - for metadata_name in predict_metadata: - predict_metadata_all[metadata_name] += batch_evalset[metadata_name] - padded_batch = preprocess_classifier_batch( - batch_evalset, max_evalset_len, label_name, gene_token_dict + batch_evalset, max_evalset_len, label_name ) - padded_batch.set_format(type="torch") - # For datasets>=4.0.0, convert to dict to avoid format issues - if 
int(datasets.__version__.split(".")[0]) >= 4: - padded_batch = padded_batch[:] - input_data_batch = padded_batch["input_ids"] attn_msk_batch = padded_batch["attention_mask"] label_batch = padded_batch[label_name] @@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene y_pred = [vote(item[0]) for item in logit_label_paired] y_true = [item[1] for item in logit_label_paired] logits_list = [item[0] for item in logit_label_paired] - - return y_pred, y_true, logits_list, predict_metadata_all + return y_pred, y_true, logits_list def get_metrics(y_pred, y_true, logits_list, num_classes, labels): @@ -197,18 +182,14 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix for model_name in roc_metric_dict.keys(): mean_fpr = roc_metric_dict[model_name]["mean_fpr"] mean_tpr = roc_metric_dict[model_name]["mean_tpr"] + roc_auc = roc_metric_dict[model_name]["roc_auc"] + roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"] color = model_style_dict[model_name]["color"] linestyle = model_style_dict[model_name]["linestyle"] - if "roc_auc" not in roc_metric_dict[model_name].keys(): - all_roc_auc = roc_metric_dict[model_name]["all_roc_auc"] - label = f"{model_name} (AUC {all_roc_auc:0.2f})" + if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1: + label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})" else: - roc_auc = roc_metric_dict[model_name]["roc_auc"] - roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"] - if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1: - label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})" - else: - label = f"{model_name} (AUC {roc_auc:0.2f})" + label = f"{model_name} (AUC {roc_auc:0.2f})" plt.plot( mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label ) diff --git a/geneformer/gene_median_dictionary_gc104M.pkl b/geneformer/gene_median_dictionary_gc95M.pkl similarity index 100% rename from geneformer/gene_median_dictionary_gc104M.pkl rename to geneformer/gene_median_dictionary_gc95M.pkl diff --git a/geneformer/gene_name_id_dict_gc104M.pkl b/geneformer/gene_name_id_dict_gc104M.pkl deleted file mode 100644 index f98c94dd0c7ff50b6b74691d75c66be5affc9fa1..0000000000000000000000000000000000000000 --- a/geneformer/gene_name_id_dict_gc104M.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1 -size 1660882 diff --git a/geneformer/gene_name_id_dict_gc95M.pkl b/geneformer/gene_name_id_dict_gc95M.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f397337d26d3eddf66cb89183047a9e38cea5988 --- /dev/null +++ b/geneformer/gene_name_id_dict_gc95M.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326 +size 2038812 diff --git a/geneformer/in_silico_perturber.py b/geneformer/in_silico_perturber.py index 35244c22237abd7e24938bb4324c2129551634a7..3e0dc7822b00b0cc2050a025aef824f94c7806ca 100644 --- a/geneformer/in_silico_perturber.py +++ b/geneformer/in_silico_perturber.py @@ -40,7 +40,7 @@ import pickle from collections import defaultdict import torch -from datasets import Dataset +from datasets import Dataset, disable_progress_bars from multiprocess import set_start_method from tqdm.auto import trange @@ -48,9 +48,7 @@ from . import TOKEN_DICTIONARY_FILE from . 
import perturber_utils as pu from .emb_extractor import get_embs -import datasets -datasets.logging.disable_progress_bar() - +disable_progress_bars() logger = logging.getLogger(__name__) @@ -62,7 +60,7 @@ class InSilicoPerturber: "genes_to_perturb": {"all", list}, "combos": {0, 1}, "anchor_gene": {None, str}, - "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"}, + "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}, "num_classes": {int}, "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"}, "cell_emb_style": {"mean_pool"}, @@ -72,7 +70,6 @@ class InSilicoPerturber: "max_ncells": {None, int}, "cell_inds_to_perturb": {"all", dict}, "emb_layer": {-1, 0}, - "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "forward_batch_size": {int}, "nproc": {int}, @@ -87,7 +84,7 @@ class InSilicoPerturber: anchor_gene=None, model_type="Pretrained", num_classes=0, - emb_mode="cls", + emb_mode="cell", cell_emb_style="mean_pool", filter_data=None, cell_states_to_model=None, @@ -97,7 +94,6 @@ class InSilicoPerturber: emb_layer=-1, forward_batch_size=100, nproc=4, - model_version="V2", token_dictionary_file=None, clear_mem_ncells=1000, ): @@ -134,7 +130,7 @@ class InSilicoPerturber: | ENSEMBL ID of gene to use as anchor in combination perturbations. | For example, if combos=1 and anchor_gene="ENSG00000148400": | anchor gene will be perturbed in combination with each other gene. - model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"} + model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"} | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization). num_classes : int | If model is a gene or cell classifier, specify number of classes it was trained to classify. @@ -186,9 +182,6 @@ class InSilicoPerturber: | Batch size for forward pass. nproc : int | Number of CPU processes to use. - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). clear_mem_ncells : int @@ -229,30 +222,15 @@ class InSilicoPerturber: self.emb_layer = emb_layer self.forward_batch_size = forward_batch_size self.nproc = nproc - self.model_version = model_version self.token_dictionary_file = token_dictionary_file - self.clear_mem_ncells = clear_mem_ncells + self.clear_mem_ncells = clear_mem_ncells self.validate_options() - if self.model_version == "V1": - from . import TOKEN_DICTIONARY_FILE_30M - self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - if self.emb_mode == "cls": - self.emb_mode = "cell" - logger.warning( - "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a token." - ) - if self.emb_mode == "cls_and_gene": - self.emb_mode = "cell_and_gene" - logger.warning( - "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a token." 
- ) - # load token dictionary (Ensembl IDs:token) if self.token_dictionary_file is None: - self.token_dictionary_file = TOKEN_DICTIONARY_FILE - with open(self.token_dictionary_file, "rb") as f: + token_dictionary_file = TOKEN_DICTIONARY_FILE + with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()} @@ -816,8 +794,6 @@ class InSilicoPerturber: return example total_batch_length = len(filtered_input_data) - - if self.cell_states_to_model is None: cos_sims_dict = defaultdict(list) else: @@ -902,7 +878,7 @@ class InSilicoPerturber: ) ##### CLS and Gene Embedding Mode ##### - elif self.emb_mode == "cls_and_gene": + elif self.emb_mode == "cls_and_gene": full_original_emb = get_embs( model, minibatch, @@ -915,7 +891,6 @@ class InSilicoPerturber: silent=True, ) indices_to_perturb = perturbation_batch["perturb_index"] - # remove indices that were perturbed original_emb = pu.remove_perturbed_indices_set( full_original_emb, @@ -924,7 +899,6 @@ class InSilicoPerturber: self.tokens_to_perturb, minibatch["length"], ) - full_perturbation_emb = get_embs( model, perturbation_batch, @@ -936,7 +910,7 @@ class InSilicoPerturber: summary_stat=None, silent=True, ) - + # remove special tokens and padding original_emb = original_emb[:, 1:-1, :] if self.perturb_type == "overexpress": @@ -947,25 +921,9 @@ class InSilicoPerturber: perturbation_emb = full_perturbation_emb[ :, 1 : max(perturbation_batch["length"]) - 1, : ] - - n_perturbation_genes = perturbation_emb.size()[1] - # truncate the original embedding as necessary - if self.perturb_type == "overexpress": - def calc_perturbation_length(ids): - if ids == [-100]: - return 0 - else: - return len(ids) - - max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)]) + n_perturbation_genes = perturbation_emb.size()[1] - max_n_overflow = max(minibatch["n_overflow"]) - if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]: - original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :] - elif perturbation_emb.size()[1] < original_emb.size()[1]: - original_emb = original_emb[:, 0:max_tensor_size, :] - gene_cos_sims = pu.quant_cos_sims( perturbation_emb, original_emb, diff --git a/geneformer/in_silico_perturber_stats.py b/geneformer/in_silico_perturber_stats.py index 40c2d5fffbd87de84b294a0a5d608391d07b5128..e694d80f3f700d18c61022f872b25bcd8fb50756 100644 --- a/geneformer/in_silico_perturber_stats.py +++ b/geneformer/in_silico_perturber_stats.py @@ -640,16 +640,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token): cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i]) # quantify number of detections of each gene - if anchor_token is None: - cos_sims_full_df["N_Detections"] = [ - n_detections(i, dict_list, "cell", anchor_token) - for i in cos_sims_full_df["Gene"] - ] - else: - cos_sims_full_df["N_Detections"] = [ - n_detections(i, dict_list, "gene", anchor_token) - for i in cos_sims_full_df["Gene"] - ] + cos_sims_full_df["N_Detections"] = [ + n_detections(i, dict_list, "gene", anchor_token) + for i in cos_sims_full_df["Gene"] + ] if combos == 0: cos_sims_full_df = cos_sims_full_df.sort_values( @@ -676,7 +670,6 @@ class InSilicoPerturberStats: "anchor_gene": {None, str}, "cell_states_to_model": {None, dict}, "pickle_suffix": {None, str}, - "model_version": {"V1", "V2"}, } def __init__( @@ -687,7 +680,6 @@ class InSilicoPerturberStats: 
anchor_gene=None, cell_states_to_model=None, pickle_suffix="_raw.pickle", - model_version="V2", token_dictionary_file=TOKEN_DICTIONARY_FILE, gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE, ): @@ -715,7 +707,7 @@ class InSilicoPerturberStats: | analyzes data for anchor gene perturbed in combination with each other gene. | However, if combos=0 and anchor_gene="ENSG00000136574": | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene. - cell_states_to_model : None, dict + cell_states_to_model: None, dict | Cell states to model if testing perturbations that achieve goal state change. | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states | state_key: key specifying name of column in .dataset that defines the start/goal states @@ -726,12 +718,6 @@ class InSilicoPerturberStats: | "start_state": "dcm", | "goal_state": "nf", | "alt_states": ["hcm", "other1", "other2"]} - pickle_suffix : None, str - | Suffix to subselect intermediate raw files for analysis. - | Default output of InSilicoPerturber uses suffix "_raw.pickle". - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). gene_name_id_dictionary_file : Path @@ -744,15 +730,9 @@ class InSilicoPerturberStats: self.anchor_gene = anchor_gene self.cell_states_to_model = cell_states_to_model self.pickle_suffix = pickle_suffix - self.model_version = model_version self.validate_options() - if self.model_version == "V1": - from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M - token_dictionary_file=TOKEN_DICTIONARY_FILE_30M - gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE_30M - # load token dictionary (Ensembl IDs:token) with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) diff --git a/geneformer/mtl/__init__.py b/geneformer/mtl/__init__.py index 5f5235c230cb0d1c51e2823e5b92118bb480dba9..06788a56ac11397d1698a74381d466b7b7bd98b7 100644 --- a/geneformer/mtl/__init__.py +++ b/geneformer/mtl/__init__.py @@ -1,4 +1 @@ -# ruff: noqa: F401 - -from . import eval_utils -from . import utils \ No newline at end of file +# ruff: noqa: F401 \ No newline at end of file diff --git a/geneformer/mtl/collators.py b/geneformer/mtl/collators.py index b575b23ba592a6f552928b306c7882de8dbe7f4f..63546f93a05c857781198be88de027f5fb9e827f 100644 --- a/geneformer/mtl/collators.py +++ b/geneformer/mtl/collators.py @@ -1,8 +1,8 @@ # imports import torch import pickle -from geneformer.collator_for_classification import DataCollatorForGeneClassification -from geneformer import TOKEN_DICTIONARY_FILE +from ..collator_for_classification import DataCollatorForGeneClassification +from .. 
import TOKEN_DICTIONARY_FILE """Geneformer collator for multi-task cell classification.""" diff --git a/geneformer/mtl/data.py b/geneformer/mtl/data.py index 7512f95a98a58e5e2d6076fa863bd628b0a24302..119f775efc757b26d1fc76ee0d3e9c48ce34db8e 100644 --- a/geneformer/mtl/data.py +++ b/geneformer/mtl/data.py @@ -1,237 +1,150 @@ import os -import pickle -import torch -from torch.utils.data import DataLoader, Dataset -from datasets import load_from_disk from .collators import DataCollatorForMultitaskCellClassification +from .imports import * -class StreamingMultiTaskDataset(Dataset): - - def __init__(self, dataset_path, config, is_test=False, dataset_type=""): - """Initialize the streaming dataset.""" - self.dataset = load_from_disk(dataset_path) - self.config = config - self.is_test = is_test - self.dataset_type = dataset_type - self.cell_id_mapping = {} - - # Setup task and column mappings - self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))] - self.task_to_column = dict(zip(self.task_names, config["task_columns"])) - config["task_names"] = self.task_names - - # Check if unique_cell_id column exists in the dataset - self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names - print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset") - - # Setup label mappings - self.label_mappings_path = os.path.join( - config["results_dir"], - f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl" - ) - +def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""): + try: + dataset = load_from_disk(dataset_path) + + task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))] + task_to_column = dict(zip(task_names, config["task_columns"])) + config["task_names"] = task_names + if not is_test: - self._validate_columns() - self.task_label_mappings, self.num_labels_list = self._create_label_mappings() - self._save_label_mappings() - else: - # Load existing mappings for test data - self.task_label_mappings = self._load_label_mappings() - self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()] - - def _validate_columns(self): - """Ensures required columns are present in the dataset.""" - missing_columns = [col for col in self.task_to_column.values() - if col not in self.dataset.column_names] - if missing_columns: - raise KeyError( - f"Missing columns in {self.dataset_type} dataset: {missing_columns}. " - f"Available columns: {self.dataset.column_names}" - ) - - def _create_label_mappings(self): - """Creates label mappings for the dataset.""" + available_columns = set(dataset.column_names) + for column in task_to_column.values(): + if column not in available_columns: + raise KeyError( + f"Column {column} not found in the dataset. 
Available columns: {list(available_columns)}" + ) + + label_mappings = {} task_label_mappings = {} + cell_id_mapping = {} num_labels_list = [] - - for task, column in self.task_to_column.items(): - unique_values = sorted(set(self.dataset[column])) - mapping = {label: idx for idx, label in enumerate(unique_values)} - task_label_mappings[task] = mapping - num_labels_list.append(len(unique_values)) - - return task_label_mappings, num_labels_list - - def _save_label_mappings(self): - """Saves label mappings to a pickle file.""" - with open(self.label_mappings_path, "wb") as f: - pickle.dump(self.task_label_mappings, f) - - def _load_label_mappings(self): - """Loads label mappings from a pickle file.""" - with open(self.label_mappings_path, "rb") as f: - return pickle.load(f) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - record = self.dataset[idx] - - # Store cell ID mapping - if self.has_unique_cell_ids: - unique_cell_id = record["unique_cell_id"] - self.cell_id_mapping[idx] = unique_cell_id - else: - self.cell_id_mapping[idx] = f"cell_{idx}" - - # Create transformed record - transformed_record = { - "input_ids": torch.tensor(record["input_ids"], dtype=torch.long), - "cell_id": idx, - } - - # Add labels - if not self.is_test: - label_dict = { - task: self.task_label_mappings[task][record[column]] - for task, column in self.task_to_column.items() - } + + # Load or create task label mappings + if not is_test: + for task, column in task_to_column.items(): + unique_values = sorted(set(dataset[column])) # Ensure consistency + label_mappings[column] = { + label: idx for idx, label in enumerate(unique_values) + } + task_label_mappings[task] = label_mappings[column] + num_labels_list.append(len(unique_values)) + + # Print the mappings for each task with dataset type prefix + for task, mapping in task_label_mappings.items(): + print( + f"{dataset_type.capitalize()} mapping for {task}: {mapping}" + ) # sanity check, for train/validation splits + + # Save the task label mappings as a pickle file + with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f: + pickle.dump(task_label_mappings, f) else: - label_dict = {task: -1 for task in self.config["task_names"]} - - transformed_record["label"] = label_dict - - return transformed_record + # Load task label mappings from pickle file for test data + with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f: + task_label_mappings = pickle.load(f) + + # Infer num_labels_list from task_label_mappings + for task, mapping in task_label_mappings.items(): + num_labels_list.append(len(mapping)) + + # Store unique cell IDs in a separate dictionary + for idx, record in enumerate(dataset): + cell_id = record.get("unique_cell_id", idx) + cell_id_mapping[idx] = cell_id + + # Transform records to the desired format + transformed_dataset = [] + for idx, record in enumerate(dataset): + transformed_record = {} + transformed_record["input_ids"] = torch.tensor( + record["input_ids"], dtype=torch.long + ) + # Use index-based cell ID for internal tracking + transformed_record["cell_id"] = idx -def get_data_loader(dataset, batch_size, sampler=None, shuffle=True): - """Create a DataLoader with the given dataset and parameters.""" - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - shuffle=shuffle if sampler is None else False, - num_workers=0, - pin_memory=True, - collate_fn=DataCollatorForMultitaskCellClassification(), - ) + if not is_test: + # Prepare labels + label_dict = {} + for 
task, column in task_to_column.items(): + label_value = record[column] + label_index = task_label_mappings[task][label_value] + label_dict[task] = label_index + transformed_record["label"] = label_dict + else: + # Create dummy labels for test data + label_dict = {task: -1 for task in config["task_names"]} + transformed_record["label"] = label_dict + transformed_dataset.append(transformed_record) -def prepare_data_loaders(config, include_test=False): - """Prepare data loaders for training, validation, and optionally test.""" - result = {} - - # Process train data - train_dataset = StreamingMultiTaskDataset( - config["train_path"], - config, - dataset_type="train" - ) - result["train_loader"] = get_data_loader(train_dataset, config["batch_size"]) - - # Store the cell ID mapping from the dataset - result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()} - print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset") - - result["num_labels_list"] = train_dataset.num_labels_list - - # Process validation data - val_dataset = StreamingMultiTaskDataset( - config["val_path"], - config, - dataset_type="validation" - ) - result["val_loader"] = get_data_loader(val_dataset, config["batch_size"]) - - # Store the complete cell ID mapping for validation - for idx in range(len(val_dataset)): - _ = val_dataset[idx] - - result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()} - print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset") - - # Validate label mappings - validate_label_mappings(config) - - # Process test data if requested - if include_test and "test_path" in config: - test_dataset = StreamingMultiTaskDataset( - config["test_path"], - config, - is_test=True, - dataset_type="test" - ) - result["test_loader"] = get_data_loader(test_dataset, config["batch_size"]) - - for idx in range(len(test_dataset)): - _ = test_dataset[idx] - - result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()} - print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset") - - return result - - -def validate_label_mappings(config): - """Ensures train and validation label mappings are consistent.""" - train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl") - val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl") - - with open(train_mappings_path, "rb") as f: - train_mappings = pickle.load(f) - - with open(val_mappings_path, "rb") as f: - val_mappings = pickle.load(f) - - for task_name in config["task_names"]: - if train_mappings[task_name] != val_mappings[task_name]: - raise ValueError( - f"Mismatch in label mappings for task '{task_name}'.\n" - f"Train Mapping: {train_mappings[task_name]}\n" - f"Validation Mapping: {val_mappings[task_name]}" - ) + return transformed_dataset, cell_id_mapping, num_labels_list + except KeyError as e: + print(f"Missing configuration or dataset key: {e}") + except Exception as e: + print(f"An error occurred while loading or preprocessing data: {e}") + return None, None, None -# Legacy functions for backward compatibility def preload_and_process_data(config): - """Preloads and preprocesses train and validation datasets.""" - data = prepare_data_loaders(config) - + # Load and preprocess data once + train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data( + config["train_path"], config, dataset_type="train" + ) + val_dataset, val_cell_id_mapping, _ = 
load_and_preprocess_data( + config["val_path"], config, dataset_type="validation" + ) return ( - data["train_loader"].dataset, - data["train_cell_mapping"], - data["val_loader"].dataset, - data["val_cell_mapping"], - data["num_labels_list"] + train_dataset, + train_cell_id_mapping, + val_dataset, + val_cell_id_mapping, + num_labels_list, ) +def get_data_loader(preprocessed_dataset, batch_size): + nproc = os.cpu_count() ### I/O operations + + data_collator = DataCollatorForMultitaskCellClassification() + + loader = DataLoader( + preprocessed_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=data_collator, + num_workers=nproc, + pin_memory=True, + ) + return loader + + def preload_data(config): - """Preprocesses train and validation data for trials.""" - data = prepare_data_loaders(config) - return data["train_loader"], data["val_loader"] + # Preprocessing the data before the Optuna trials start + train_loader = get_data_loader("train", config) + val_loader = get_data_loader("val", config) + return train_loader, val_loader def load_and_preprocess_test_data(config): - """Loads and preprocesses test data.""" - test_dataset = StreamingMultiTaskDataset( - config["test_path"], - config, - is_test=True, - dataset_type="test" - ) - - return ( - test_dataset, - test_dataset.cell_id_mapping, - test_dataset.num_labels_list - ) + """ + Load and preprocess test data, treating it as unlabeled. + """ + return load_and_preprocess_data(config["test_path"], config, is_test=True) def prepare_test_loader(config): - """Prepares DataLoader for test data.""" - data = prepare_data_loaders(config, include_test=True) - return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"] \ No newline at end of file + """ + Prepare DataLoader for the test dataset. 
+ """ + test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data( + config + ) + test_loader = get_data_loader(test_dataset, config["batch_size"]) + return test_loader, cell_id_mapping, num_labels_list diff --git a/geneformer/mtl/eval_utils.py b/geneformer/mtl/eval_utils.py index 29f98214362f23e4492814948359af2b145abba8..0a8ea4babe4ab1e48cc56280ee03423075cf7563 100644 --- a/geneformer/mtl/eval_utils.py +++ b/geneformer/mtl/eval_utils.py @@ -1,16 +1,19 @@ -import os -import json -import torch import pandas as pd -from .data import prepare_test_loader +from .imports import * # noqa # isort:skip +from .data import prepare_test_loader # noqa # isort:skip from .model import GeneformerMultiTask + def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config): task_pred_labels = {task_name: [] for task_name in config["task_names"]} task_pred_probs = {task_name: [] for task_name in config["task_names"]} cell_ids = [] + # # Load task label mappings from pickle file + # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f: + # task_label_mappings = pickle.load(f) + model.eval() with torch.no_grad(): for batch in test_loader: @@ -82,4 +85,4 @@ def load_and_evaluate_test_model(config): best_model.to(device) evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config) - print("Evaluation completed.") \ No newline at end of file + print("Evaluation completed.") diff --git a/geneformer/mtl/imports.py b/geneformer/mtl/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe9e90945a10a3d79cc487fa15431f2915e5683 --- /dev/null +++ b/geneformer/mtl/imports.py @@ -0,0 +1,43 @@ +import functools +import gc +import json +import os +import pickle +import sys +import warnings +from enum import Enum +from itertools import chain +from typing import Dict, List, Optional, Union + +import numpy as np +import optuna +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from datasets import load_from_disk +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder +from torch.utils.data import DataLoader +from transformers import ( + AdamW, + BatchEncoding, + BertConfig, + BertModel, + DataCollatorForTokenClassification, + SpecialTokensMixin, + get_cosine_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_scheduler, +) +from transformers.utils import logging, to_py_obj + +from .collators import DataCollatorForMultitaskCellClassification + +# local modules +from .data import get_data_loader, preload_and_process_data +from .model import GeneformerMultiTask +from .optuna_utils import create_optuna_study +from .utils import save_model diff --git a/geneformer/mtl/model.py b/geneformer/mtl/model.py index 0b03ff31bf81c458756b9afc7ead315e53c8d63b..393ebfad4f44f98d748845ea1ae81d66139988f5 100644 --- a/geneformer/mtl/model.py +++ b/geneformer/mtl/model.py @@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module): f"Error during loss computation for task {task_id}: {e}" ) - return total_loss, logits, losses if labels is not None else logits \ No newline at end of file + return total_loss, logits, losses if labels is not None else logits diff --git a/geneformer/mtl/optuna_utils.py b/geneformer/mtl/optuna_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..47f375e90f4030e15feb7bc1245ffbba3e6a086e --- /dev/null 
+++ b/geneformer/mtl/optuna_utils.py @@ -0,0 +1,27 @@ +import optuna +from optuna.integration import TensorBoardCallback + + +def save_trial_callback(study, trial, trials_result_path): + with open(trials_result_path, "a") as f: + f.write( + f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n" + ) + + +def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir): + study = optuna.create_study(direction="maximize") + + # init TensorBoard callback + tensorboard_callback = TensorBoardCallback( + dirname=tensorboard_log_dir, metric_name="F1 Macro" + ) + + # callback and TensorBoard callback + callbacks = [ + lambda study, trial: save_trial_callback(study, trial, trials_result_path), + tensorboard_callback, + ] + + study.optimize(objective, n_trials=n_trials, callbacks=callbacks) + return study diff --git a/geneformer/mtl/train.py b/geneformer/mtl/train.py index 02c00c706111b6b4198aca3c1dbb7bfcba4846a5..5dee1fb8baf594fb137dce3802a44cc0118f1558 100644 --- a/geneformer/mtl/train.py +++ b/geneformer/mtl/train.py @@ -1,707 +1,380 @@ import os +import random + +import numpy as np import pandas as pd import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -import optuna -import functools -import time +from .imports import * from .model import GeneformerMultiTask -from .utils import ( - calculate_metrics, - get_layer_freeze_range, - set_seed, - initialize_wandb, - create_model, - setup_optimizer_and_scheduler, - save_model, - save_hyperparameters, - prepare_training_environment, - log_training_step, - log_validation_metrics, - save_validation_predictions, - setup_logging, - setup_distributed_environment, - train_distributed -) - - -class Trainer: - """Trainer class for multi-task learning""" - - def __init__(self, config): - self.config = config - self.device = None - self.model = None - self.optimizer = None - self.scheduler = None - self.writer = None - self.is_distributed = config.get("distributed_training", False) - self.local_rank = config.get("local_rank", 0) - self.is_main_process = not self.is_distributed or self.local_rank == 0 - - def train_epoch(self, train_loader, epoch): - """Train the model for one epoch.""" - epoch_start = time.time() - self.model.train() - - # For distributed training, we need to be aware of the global batch count - if self.is_distributed: - # Get world size for reporting - world_size = dist.get_world_size() - # Calculate total batches across all GPUs - total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader) - else: - world_size = 1 - total_batches_global = len(train_loader) - - progress_bar = None - if self.is_main_process: - # Use the global batch count for progress reporting in distributed mode - progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}", - total=len(train_loader)) - iterator = progress_bar - - # Report distributed training information - if self.is_distributed: - print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, " - f"{total_batches_global} total batches globally") - else: - iterator = train_loader - - batch_times = [] - forward_times = [] - backward_times = [] - optimizer_times = [] - - # Get gradient accumulation steps from config (default to 1 if not specified) - accumulation_steps = self.config.get("gradient_accumulation_steps", 1) 
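# The removed block below drives gradient accumulation: backward() runs every
# micro-batch, but optimizer.step() fires only every `accumulation_steps`
# batches, and under DDP the gradient all-reduce is suppressed with
# model.no_sync() until the micro-batch that actually updates weights.
# A minimal standalone sketch of that pattern, assuming a DDP-wrapped model;
# the function and variable names here are illustrative, not from the repo:

import contextlib

import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def sync_context(model, sync_now):
    # Skip DDP's gradient all-reduce on non-update micro-batches.
    if isinstance(model, DDP) and not sync_now:
        return model.no_sync()
    return contextlib.nullcontext()


def train_with_accumulation(model, optimizer, scheduler, batches, accumulation_steps):
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        # Sync on each accumulation boundary and on the final (possibly partial) group.
        sync_now = (i + 1) % accumulation_steps == 0 or (i + 1) == len(batches)
        with sync_context(model, sync_now):
            loss = F.cross_entropy(model(inputs), targets)
            # Scale so the accumulated gradient matches one large-batch step.
            (loss / accumulation_steps).backward()
        if sync_now:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()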
- - # Zero gradients at the beginning - self.optimizer.zero_grad() - - # Track loss for the entire epoch - total_loss = 0.0 - num_batches = 0 - accumulated_loss = 0.0 - - for batch_idx, batch in enumerate(iterator): - batch_start = time.time() - - input_ids = batch["input_ids"].to(self.device) - attention_mask = batch["attention_mask"].to(self.device) +from .utils import calculate_task_specific_metrics, get_layer_freeze_range + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def initialize_wandb(config): + if config.get("use_wandb", False): + import wandb + + wandb.init(project=config["wandb_project"], config=config) + print("Weights & Biases (wandb) initialized and will be used for logging.") + else: + print( + "Weights & Biases (wandb) is not enabled. Logging will use other methods." + ) + + +def create_model(config, num_labels_list, device): + model = GeneformerMultiTask( + config["pretrained_path"], + num_labels_list, + dropout_rate=config["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=config["task_weights"], + max_layers_to_freeze=config["max_layers_to_freeze"], + use_attention_pooling=config["use_attention_pooling"], + ) + if config["use_data_parallel"]: + model = nn.DataParallel(model) + return model.to(device) + + +def setup_optimizer_and_scheduler(model, config, total_steps): + optimizer = AdamW( + model.parameters(), + lr=config["learning_rate"], + weight_decay=config["weight_decay"], + ) + warmup_steps = int(config["warmup_ratio"] * total_steps) + + if config["lr_scheduler_type"] == "linear": + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps + ) + elif config["lr_scheduler_type"] == "cosine": + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps, + num_cycles=0.5, + ) + + return optimizer, scheduler + + +def train_epoch( + model, train_loader, optimizer, scheduler, device, config, writer, epoch +): + model.train() + progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}") + for batch_idx, batch in enumerate(progress_bar): + optimizer.zero_grad() + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + labels = [ + batch["labels"][task_name].to(device) for task_name in config["task_names"] + ] + + loss, _, _ = model(input_ids, attention_mask, labels) + loss.backward() + + if config["gradient_clipping"]: + torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"]) + + optimizer.step() + scheduler.step() + + writer.add_scalar( + "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx + ) + if config.get("use_wandb", False): + import wandb + + wandb.log({"Training Loss": loss.item()}) + + # Update progress bar + progress_bar.set_postfix({"loss": f"{loss.item():.4f}"}) + + return loss.item() # Return the last batch loss + + +def validate_model(model, val_loader, device, config): + model.eval() + val_loss = 0.0 + task_true_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_probs = {task_name: [] for task_name in config["task_names"]} + + with torch.no_grad(): + for batch in val_loader: + input_ids = batch["input_ids"].to(device) + attention_mask = 
batch["attention_mask"].to(device) labels = [ - batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"] + batch["labels"][task_name].to(device) + for task_name in config["task_names"] ] + loss, logits, _ = model(input_ids, attention_mask, labels) + val_loss += loss.item() + + for sample_idx in range(len(batch["input_ids"])): + for i, task_name in enumerate(config["task_names"]): + true_label = batch["labels"][task_name][sample_idx].item() + pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() + pred_prob = ( + torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy() + ) + task_true_labels[task_name].append(true_label) + task_pred_labels[task_name].append(pred_label) + task_pred_probs[task_name].append(pred_prob) + + val_loss /= len(val_loader) + return val_loss, task_true_labels, task_pred_labels, task_pred_probs + + +def log_metrics(task_metrics, val_loss, config, writer, epochs): + for task_name, metrics in task_metrics.items(): + print( + f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}" + ) + if config.get("use_wandb", False): + import wandb - forward_start = time.time() - loss, _, _ = self.model(input_ids, attention_mask, labels) - - # Scale loss by accumulation steps for gradient accumulation - if accumulation_steps > 1: - loss = loss / accumulation_steps - - forward_end = time.time() - forward_times.append(forward_end - forward_start) - - # Track loss - store the unscaled loss for reporting - unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps) - total_loss += unscaled_loss - num_batches += 1 - accumulated_loss += loss.item() # For gradient accumulation tracking - - backward_start = time.time() - - # Use no_sync() for all but the last accumulation step to avoid unnecessary communication - if self.is_distributed and accumulation_steps > 1: - # If this is not the last accumulation step or the last batch - if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader): - with self.model.no_sync(): - loss.backward() - else: - loss.backward() - else: - # Non-distributed training or accumulation_steps=1 - loss.backward() - - backward_end = time.time() - backward_times.append(backward_end - backward_start) - - # Only update weights after accumulation_steps or at the end of the epoch - if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader): - if self.config["gradient_clipping"]: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"]) - - optimizer_start = time.time() - self.optimizer.step() - self.scheduler.step() - self.optimizer.zero_grad() - optimizer_end = time.time() - optimizer_times.append(optimizer_end - optimizer_start) - - # Log after optimizer step - if self.is_main_process: - # Calculate running average loss - avg_loss = total_loss / num_batches - - log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx) - - # Update progress bar with just the running average loss - progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"}) - - accumulated_loss = 0.0 - else: - optimizer_times.append(0) # No optimizer step taken - - batch_end = time.time() - batch_times.append(batch_end - batch_start) - - epoch_end = time.time() - - # Calculate average loss for the epoch - epoch_avg_loss = total_loss / num_batches - - # If distributed, gather losses from all processes to compute global average - if self.is_distributed: - # Create a tensor to hold the loss 
- loss_tensor = torch.tensor([epoch_avg_loss], device=self.device) - # Gather losses from all processes - all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())] - dist.all_gather(all_losses, loss_tensor) - # Compute the global average loss across all processes - epoch_avg_loss = torch.mean(torch.stack(all_losses)).item() - - if self.is_main_process: - # douhble check if batch_size has already been adjusted for world_size in the config - # This avoids double-counting the effective batch size - per_gpu_batch_size = self.config['batch_size'] - total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size - - print(f"Epoch {epoch+1} timing:") - print(f" Total epoch time: {epoch_end - epoch_start:.2f}s") - print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s") - print(f" Average forward time: {sum(forward_times)/len(forward_times):.4f}s") - print(f" Average backward time: {sum(backward_times)/len(backward_times):.4f}s") - print(f" Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s") - print(f" Gradient accumulation steps: {accumulation_steps}") - print(f" Batch size per GPU: {per_gpu_batch_size}") - print(f" Effective global batch size: {total_effective_batch}") - print(f" Average training loss: {epoch_avg_loss:.4f}") - if self.is_distributed: - print(f" Total batches processed across all GPUs: {total_batches_global}") - print(f" Communication optimization: Using no_sync() for gradient accumulation") - - return epoch_avg_loss # Return the average loss for the epoch - - def validate_model(self, val_loader): - val_start = time.time() - self.model.eval() - val_loss = 0.0 - task_true_labels = {task_name: [] for task_name in self.config["task_names"]} - task_pred_labels = {task_name: [] for task_name in self.config["task_names"]} - task_pred_probs = {task_name: [] for task_name in self.config["task_names"]} - - val_cell_ids = {} - sample_counter = 0 - - batch_times = [] - - # Print validation dataset size - if self.is_main_process: - print(f"Validation dataset size: {len(val_loader.dataset)} samples") - print(f"Number of validation batches: {len(val_loader)}") - - if self.is_distributed: - world_size = dist.get_world_size() - print(f"Distributed validation: {world_size} GPUs") - if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'): - samples_per_gpu = val_loader.sampler.num_samples - print(f"Each GPU processes {samples_per_gpu} validation samples") - print(f"Total validation samples processed: {samples_per_gpu * world_size}") - - with torch.no_grad(): - for batch in val_loader: - batch_start = time.time() - input_ids = batch["input_ids"].to(self.device) - attention_mask = batch["attention_mask"].to(self.device) - labels = [ - batch["labels"][task_name].to(self.device) - for task_name in self.config["task_names"] - ] - loss, logits, _ = self.model(input_ids, attention_mask, labels) - val_loss += loss.item() - - if "cell_id" in batch: - for i, cell_id in enumerate(batch["cell_id"]): - # Store the actual index for later mapping to unique_cell_id - val_cell_ids[sample_counter + i] = cell_id.item() - - for sample_idx in range(len(batch["input_ids"])): - for i, task_name in enumerate(self.config["task_names"]): - true_label = batch["labels"][task_name][sample_idx].item() - pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() - # Store the full probability distribution - pred_prob = torch.softmax(logits[i][sample_idx], 
- pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
- task_true_labels[task_name].append(true_label)
- task_pred_labels[task_name].append(pred_label)
- task_pred_probs[task_name].append(pred_prob)
-
- # Update current index for cell ID tracking
- sample_counter += len(batch["input_ids"])
-
- batch_end = time.time()
- batch_times.append(batch_end - batch_start)
-
- # Normalize the validation loss by the number of batches
- val_loss /= len(val_loader)
-
- # If distributed, gather results from all processes
- if self.is_distributed:
- # Create tensors to hold the local results
- loss_tensor = torch.tensor([val_loss], device=self.device)
- gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
- dist.all_gather(gathered_losses, loss_tensor)
-
- # Compute average loss across all processes
- val_loss = torch.mean(torch.cat(gathered_losses)).item()
-
- world_size = dist.get_world_size()
-
- if self.is_main_process:
- print(f"Collected predictions from rank {self.local_rank}")
- print(f"Number of samples processed by this rank: {sample_counter}")
-
- val_end = time.time()
-
- if self.is_main_process:
- print(f"Validation timing:")
- print(f" Total validation time: {val_end - val_start:.2f}s")
- print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
- print(f" Collected {len(val_cell_ids)} cell indices from validation data")
- print(f" Processed {sample_counter} total samples during validation")
-
- # Print number of samples per task
- for task_name in self.config["task_names"]:
- print(f" Task {task_name}: {len(task_true_labels[task_name])} samples")
-
- return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids
-
- def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
- """Train the model and return validation loss and trained model."""
- if self.config.get("use_wandb", False) and self.is_main_process:
- initialize_wandb(self.config)
-
- # Create model
- self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)
-
- # Setup optimizer and scheduler
- total_steps = len(train_loader) * self.config["epochs"]
- self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)
-
- # Training loop
- if self.is_main_process:
- epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
- else:
- epoch_progress = range(self.config["epochs"])
-
- best_val_loss = float('inf')
- train_losses = []
-
- with setup_logging(self.config) as self.writer:
- for epoch in epoch_progress:
- if self.is_distributed:
- train_loader.sampler.set_epoch(epoch)
-
- train_loss = self.train_epoch(train_loader, epoch)
- train_losses.append(train_loss)
-
- # Run validation after each epoch if configured to do so
- if self.config.get("validate_each_epoch", False):
- val_loss, _, _, _, _ = self.validate_model(val_loader)
- if val_loss < best_val_loss:
- best_val_loss = val_loss
-
- if self.is_main_process:
- epoch_progress.set_postfix({
- "train_loss": f"{train_loss:.4f}",
- "val_loss": f"{val_loss:.4f}",
- "best_val_loss": f"{best_val_loss:.4f}"
- })
- else:
- if self.is_main_process:
- epoch_progress.set_postfix({
- "train_loss": f"{train_loss:.4f}"
- })
-
- val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
- task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
-
- if self.is_main_process:
- log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])
-
- # Save validation predictions
- save_validation_predictions(
- val_cell_ids,
- task_true_labels,
- task_pred_labels,
- task_pred_probs,
- {**self.config, "val_cell_mapping": val_cell_id_mapping} # Include the mapping
- )
-
- if self.config.get("use_wandb", False):
- import wandb
- wandb.finish()
-
- print(f"\nTraining Summary:")
- print(f" Final Training Loss: {train_losses[-1]:.4f}")
- print(f" Final Validation Loss: {val_loss:.4f}")
- for task_name, metrics in task_metrics.items():
- print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
-
- return val_loss, self.model # Return both the validation loss and the trained model
-
- def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
- if self.is_distributed:
- self.device = torch.device(f"cuda:{self.local_rank}")
- else:
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- self.model = create_model(self.config, num_labels_list, self.device)
-
- # Wrap the model with DDP
- if self.is_distributed:
- self.model = DDP(self.model, device_ids=[self.local_rank])
-
- # Communication hook to optimize gradient synchronization
- from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
-
- # Default hook, which maintains full precision
- self.model.register_comm_hook(
- state=None,
- hook=comm_hooks.allreduce_hook
+ wandb.log(
+ {
+ f"{task_name} Validation F1 Macro": metrics["f1"],
+ f"{task_name} Validation Accuracy": metrics["accuracy"],
+ }
)
-
- print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")
-
- print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
- print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
- if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
- print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")
-
- if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
- print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")
-
- # Set up optimizer and scheduler
- self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
- self.model, self.config, len(train_loader)
+
+ writer.add_scalar("Validation Loss", val_loss, epochs)
+ for task_name, metrics in task_metrics.items():
+ writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
+ writer.add_scalar(
+ f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
)
- if self.is_main_process and self.config.get("use_wandb", False):
- initialize_wandb(self.config)
-
- return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
-
-
-def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
- """Train a model with the given configuration and data."""
- # Check if distributed training is enabled
- if config.get("distributed_training", False):
- # Check if we have multiple GPUs
- if torch.cuda.device_count() > 1:
- result = train_distributed(
- Trainer,
- config,
- train_loader,
- val_loader,
- train_cell_id_mapping,
- val_cell_id_mapping,
- num_labels_list
+
+def save_validation_predictions(
+ val_cell_id_mapping,
+ task_true_labels,
+ task_pred_labels,
+
task_pred_probs, + config, + trial_number=None, +): + if trial_number is not None: + trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}") + os.makedirs(trial_results_dir, exist_ok=True) + val_preds_file = os.path.join(trial_results_dir, "val_preds.csv") + else: + val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv") + + rows = [] + for sample_idx in range(len(val_cell_id_mapping)): + row = {"Cell ID": val_cell_id_mapping[sample_idx]} + for task_name in config["task_names"]: + row[f"{task_name} True"] = task_true_labels[task_name][sample_idx] + row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx] + row[f"{task_name} Probabilities"] = ",".join( + map(str, task_pred_probs[task_name][sample_idx]) ) - if result is not None: - return result - else: - print("Distributed training requested but only one GPU found. Falling back to single GPU training.") - config["distributed_training"] = False - - # Non-distributed training - trainer = Trainer(config) - trainer.device = device - return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list) + rows.append(row) + df = pd.DataFrame(rows) + df.to_csv(val_preds_file, index=False) + print(f"Validation predictions saved to {val_preds_file}") -def objective( - trial, + +def train_model( + config, + device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, - config, - device, ): - """Objective function for Optuna hyperparameter optimization.""" set_seed(config["seed"]) initialize_wandb(config) - trial_config = config.copy() - - # Suggest hyperparameters for this trial - for param_name, param_config in config["hyperparameters"].items(): - if param_name == "lr_scheduler_type": - trial_config[param_name] = trial.suggest_categorical( - param_name, param_config["choices"] - ) - elif param_name == "task_weights" and config["use_task_weights"]: - weights = [ - trial.suggest_float( - f"task_weight_{i}", - param_config["low"], - param_config["high"], - ) - for i in range(len(num_labels_list)) - ] - weight_sum = sum(weights) - trial_config[param_name] = [w / weight_sum for w in weights] - elif "log" in param_config and param_config["log"]: - trial_config[param_name] = trial.suggest_float( - param_name, param_config["low"], param_config["high"], log=True - ) - else: - trial_config[param_name] = trial.suggest_float( - param_name, param_config["low"], param_config["high"] - ) - - # Set appropriate max layers to freeze based on pretrained model - if "max_layers_to_freeze" in trial_config: - freeze_range = get_layer_freeze_range(trial_config["pretrained_path"]) - trial_config["max_layers_to_freeze"] = int(trial.suggest_int( - "max_layers_to_freeze", - freeze_range["min"], - freeze_range["max"] - )) - - trial_config["run_name"] = f"trial_{trial.number}" - - # Handle distributed training for this trial - if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1: - manager = mp.Manager() - shared_dict = manager.dict() - - train_distributed( - Trainer, - trial_config, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, - trial.number, - shared_dict - ) - - val_loss = shared_dict.get('val_loss', float('inf')) - task_metrics = shared_dict.get('task_metrics', {}) - - trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {})) - trial.set_user_attr("task_weights", trial_config["task_weights"]) - - if config.get("use_wandb", False): - import 
wandb - wandb.log({ - "trial_number": trial.number, - "val_loss": val_loss, - **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()}, - **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()}, - }) - wandb.finish() - - return val_loss - - with setup_logging(trial_config) as writer: - trainer = Trainer(trial_config) - trainer.device = device - trainer.writer = writer - - # Create model with trial hyperparameters - trainer.model = create_model(trial_config, num_labels_list, device) - total_steps = len(train_loader) * config["epochs"] - trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps) - - # Training loop - for epoch in range(config["epochs"]): - trainer.train_epoch(train_loader, epoch) - - val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader) - task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - - # Log metrics - log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"]) - - # Save validation predictions - save_validation_predictions( - val_cell_ids, - task_true_labels, - task_pred_labels, - task_pred_probs, - {**trial_config, "val_cell_mapping": val_cell_id_mapping}, - trial.number, + model = create_model(config, num_labels_list, device) + total_steps = len(train_loader) * config["epochs"] + optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps) + + log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run") + writer = SummaryWriter(log_dir=log_dir) + + epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress") + for epoch in epoch_progress: + last_loss = train_epoch( + model, train_loader, optimizer, scheduler, device, config, writer, epoch ) + epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"}) - # Store model state dict and task weights in trial user attributes - trial.set_user_attr("model_state_dict", trainer.model.state_dict()) - trial.set_user_attr("task_weights", trial_config["task_weights"]) + val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model( + model, val_loader, device, config + ) + task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels) - # Report intermediate value to Optuna - trial.report(val_loss, config["epochs"]) - if trial.should_prune(): - raise optuna.TrialPruned() + log_metrics(task_metrics, val_loss, config, writer, config["epochs"]) + writer.close() - if config.get("use_wandb", False): - import wandb - wandb.log( - { - "trial_number": trial.number, - "val_loss": val_loss, - **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()}, - **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()}, - **{k: v for k, v in trial_config.items() if k in [ - "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate", - "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze" - ]}, - } - ) - wandb.finish() + save_validation_predictions( + val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config + ) - return val_loss + if config.get("use_wandb", False): + import wandb + wandb.finish() -def run_manual_tuning(config): - """Run training with manually specified hyperparameters.""" - ( - device, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, - ) = 
prepare_training_environment(config) + print(f"\nFinal Validation Loss: {val_loss:.4f}") + return val_loss, model # Return both the validation loss and the trained model - print("\nManual hyperparameters being used:") - for key, value in config["manual_hyperparameters"].items(): - print(f"{key}: {value}") - print() - # Update config with manual hyperparameters - for key, value in config["manual_hyperparameters"].items(): - config[key] = value +def objective( + trial, + train_loader, + val_loader, + train_cell_id_mapping, + val_cell_id_mapping, + num_labels_list, + config, + device, +): + set_seed(config["seed"]) # Set the seed before each trial + initialize_wandb(config) - # Train the model - val_loss, trained_model = train_model( - config, - device, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, + # Hyperparameters + config["learning_rate"] = trial.suggest_float( + "learning_rate", + config["hyperparameters"]["learning_rate"]["low"], + config["hyperparameters"]["learning_rate"]["high"], + log=config["hyperparameters"]["learning_rate"]["log"], + ) + config["warmup_ratio"] = trial.suggest_float( + "warmup_ratio", + config["hyperparameters"]["warmup_ratio"]["low"], + config["hyperparameters"]["warmup_ratio"]["high"], + ) + config["weight_decay"] = trial.suggest_float( + "weight_decay", + config["hyperparameters"]["weight_decay"]["low"], + config["hyperparameters"]["weight_decay"]["high"], + ) + config["dropout_rate"] = trial.suggest_float( + "dropout_rate", + config["hyperparameters"]["dropout_rate"]["low"], + config["hyperparameters"]["dropout_rate"]["high"], + ) + config["lr_scheduler_type"] = trial.suggest_categorical( + "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"] + ) + config["use_attention_pooling"] = trial.suggest_categorical( + "use_attention_pooling", [False] ) - print(f"\nValidation loss with manual hyperparameters: {val_loss}") - - # Save the trained model - only if not using distributed training - # (distributed training saves the model in the worker) - if not config.get("distributed_training", False): - model_save_directory = os.path.join( - config["model_save_path"], "GeneformerMultiTask" - ) - save_model(trained_model, model_save_directory) - - # Save the hyperparameters - hyperparams_to_save = { - **config["manual_hyperparameters"], - "dropout_rate": config["dropout_rate"], - "use_task_weights": config["use_task_weights"], - "task_weights": config["task_weights"], - "max_layers_to_freeze": config["max_layers_to_freeze"], - "use_attention_pooling": config["use_attention_pooling"], - } - save_hyperparameters(model_save_directory, hyperparams_to_save) + if config["use_task_weights"]: + config["task_weights"] = [ + trial.suggest_float( + f"task_weight_{i}", + config["hyperparameters"]["task_weights"]["low"], + config["hyperparameters"]["task_weights"]["high"], + ) + for i in range(len(num_labels_list)) + ] + weight_sum = sum(config["task_weights"]) + config["task_weights"] = [ + weight / weight_sum for weight in config["task_weights"] + ] + else: + config["task_weights"] = None + + # Dynamic range for max_layers_to_freeze + freeze_range = get_layer_freeze_range(config["pretrained_path"]) + config["max_layers_to_freeze"] = trial.suggest_int( + "max_layers_to_freeze", + freeze_range["min"], + freeze_range["max"] + ) - return val_loss + model = create_model(config, num_labels_list, device) + total_steps = len(train_loader) * config["epochs"] + optimizer, scheduler = 
setup_optimizer_and_scheduler(model, config, total_steps) + log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}") + writer = SummaryWriter(log_dir=log_dir) -def run_optuna_study(config): - """Run hyperparameter optimization using Optuna.""" - # Prepare training environment - ( - device, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, - ) = prepare_training_environment(config) - - # If manual hyperparameters are specified, use them instead of running Optuna - if config.get("use_manual_hyperparameters", False): - return run_manual_tuning(config) - - # Create a partial function with fixed arguments for the objective - objective_with_config_and_data = functools.partial( - objective, - train_loader=train_loader, - val_loader=val_loader, - train_cell_id_mapping=train_cell_id_mapping, - val_cell_id_mapping=val_cell_id_mapping, - num_labels_list=num_labels_list, - config=config, - device=device, - ) + for epoch in range(config["epochs"]): + train_epoch( + model, train_loader, optimizer, scheduler, device, config, writer, epoch + ) - # Create and run the Optuna study - study = optuna.create_study( - direction="minimize", # Minimize validation loss - study_name=config["study_name"], - # storage=config["storage"], - load_if_exists=True, + val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model( + model, val_loader, device, config ) + task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels) - study.optimize(objective_with_config_and_data, n_trials=config["n_trials"]) + log_metrics(task_metrics, val_loss, config, writer, config["epochs"]) + writer.close() - # After finding the best trial - best_params = study.best_trial.params - best_task_weights = study.best_trial.user_attrs["task_weights"] - print("Saving the best model and its hyperparameters...") - - # Create a model with the best hyperparameters - best_model = GeneformerMultiTask( - config["pretrained_path"], - num_labels_list, - dropout_rate=best_params["dropout_rate"], - use_task_weights=config["use_task_weights"], - task_weights=best_task_weights, - max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0), - use_attention_pooling=best_params.get("use_attention_pooling", False), + save_validation_predictions( + val_cell_id_mapping, + task_true_labels, + task_pred_labels, + task_pred_probs, + config, + trial.number, ) - # Get the best model state dictionary - best_model_state_dict = study.best_trial.user_attrs["model_state_dict"] + trial.set_user_attr("model_state_dict", model.state_dict()) + trial.set_user_attr("task_weights", config["task_weights"]) - best_model_state_dict = { - k.replace("module.", ""): v for k, v in best_model_state_dict.items() - } + trial.report(val_loss, config["epochs"]) - best_model.load_state_dict(best_model_state_dict, strict=False) + if trial.should_prune(): + raise optuna.TrialPruned() - model_save_directory = os.path.join( - config["model_save_path"], "GeneformerMultiTask" - ) - save_model(best_model, model_save_directory) + if config.get("use_wandb", False): + import wandb - save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights}) + wandb.log( + { + "trial_number": trial.number, + "val_loss": val_loss, + **{ + f"{task_name}_f1": metrics["f1"] + for task_name, metrics in task_metrics.items() + }, + **{ + f"{task_name}_accuracy": metrics["accuracy"] + for task_name, metrics in task_metrics.items() + }, + **{ + k: v + for k, v in config.items() + if k + 
in [ + "learning_rate", + "warmup_ratio", + "weight_decay", + "dropout_rate", + "lr_scheduler_type", + "use_attention_pooling", + "max_layers_to_freeze", + ] + }, + } + ) + wandb.finish() - return study.best_trial.value # Return the best validation loss + return val_loss diff --git a/geneformer/mtl/train_utils.py b/geneformer/mtl/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..430994a37a53dcde99666a7b5a4d99532e9bc8ba --- /dev/null +++ b/geneformer/mtl/train_utils.py @@ -0,0 +1,161 @@ +import random + +from .data import get_data_loader, preload_and_process_data +from .imports import * +from .model import GeneformerMultiTask +from .train import objective, train_model +from .utils import save_model + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def run_manual_tuning(config): + # Set seed for reproducibility + set_seed(config["seed"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ( + train_dataset, + train_cell_id_mapping, + val_dataset, + val_cell_id_mapping, + num_labels_list, + ) = preload_and_process_data(config) + train_loader = get_data_loader(train_dataset, config["batch_size"]) + val_loader = get_data_loader(val_dataset, config["batch_size"]) + + # Print the manual hyperparameters being used + print("\nManual hyperparameters being used:") + for key, value in config["manual_hyperparameters"].items(): + print(f"{key}: {value}") + print() # Add an empty line for better readability + + # Use the manual hyperparameters + for key, value in config["manual_hyperparameters"].items(): + config[key] = value + + # Train the model + val_loss, trained_model = train_model( + config, + device, + train_loader, + val_loader, + train_cell_id_mapping, + val_cell_id_mapping, + num_labels_list, + ) + + print(f"\nValidation loss with manual hyperparameters: {val_loss}") + + # Save the trained model + model_save_directory = os.path.join( + config["model_save_path"], "GeneformerMultiTask" + ) + save_model(trained_model, model_save_directory) + + # Save the hyperparameters + hyperparams_to_save = { + **config["manual_hyperparameters"], + "dropout_rate": config["dropout_rate"], + "use_task_weights": config["use_task_weights"], + "task_weights": config["task_weights"], + "max_layers_to_freeze": config["max_layers_to_freeze"], + "use_attention_pooling": config["use_attention_pooling"], + } + hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") + with open(hyperparams_path, "w") as f: + json.dump(hyperparams_to_save, f) + print(f"Manual hyperparameters saved to {hyperparams_path}") + + return val_loss + + +def run_optuna_study(config): + # Set seed for reproducibility + set_seed(config["seed"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ( + train_dataset, + train_cell_id_mapping, + val_dataset, + val_cell_id_mapping, + num_labels_list, + ) = preload_and_process_data(config) + train_loader = get_data_loader(train_dataset, config["batch_size"]) + val_loader = get_data_loader(val_dataset, config["batch_size"]) + + if config["use_manual_hyperparameters"]: + train_model( + config, + device, + train_loader, + val_loader, + train_cell_id_mapping, + val_cell_id_mapping, + num_labels_list, + ) + else: + objective_with_config_and_data = functools.partial( + objective, + train_loader=train_loader, + val_loader=val_loader, + 
train_cell_id_mapping=train_cell_id_mapping, + val_cell_id_mapping=val_cell_id_mapping, + num_labels_list=num_labels_list, + config=config, + device=device, + ) + + study = optuna.create_study( + direction="minimize", # Minimize validation loss + study_name=config["study_name"], + # storage=config["storage"], + load_if_exists=True, + ) + + study.optimize(objective_with_config_and_data, n_trials=config["n_trials"]) + + # After finding the best trial + best_params = study.best_trial.params + best_task_weights = study.best_trial.user_attrs["task_weights"] + print("Saving the best model and its hyperparameters...") + + # Saving model as before + best_model = GeneformerMultiTask( + config["pretrained_path"], + num_labels_list, + dropout_rate=best_params["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=best_task_weights, + ) + + # Get the best model state dictionary + best_model_state_dict = study.best_trial.user_attrs["model_state_dict"] + + # Remove the "module." prefix from the state dictionary keys if present + best_model_state_dict = { + k.replace("module.", ""): v for k, v in best_model_state_dict.items() + } + + # Load the modified state dictionary into the model, skipping unexpected keys + best_model.load_state_dict(best_model_state_dict, strict=False) + + model_save_directory = os.path.join( + config["model_save_path"], "GeneformerMultiTask" + ) + save_model(best_model, model_save_directory) + + # Additionally, save the best hyperparameters and task weights + hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") + + with open(hyperparams_path, "w") as f: + json.dump({**best_params, "task_weights": best_task_weights}, f) + print(f"Best hyperparameters and task weights saved to {hyperparams_path}") diff --git a/geneformer/mtl/utils.py b/geneformer/mtl/utils.py index 1a59a63b4a4a9d2642ed5d504bd1a38b054d10be..5de5079ffdefb853a183038a6b3956de42f19978 100644 --- a/geneformer/mtl/utils.py +++ b/geneformer/mtl/utils.py @@ -1,641 +1,129 @@ -from typing import Dict, List, Optional, Union -import json import os -import pickle -import random -import torch -import numpy as np -import optuna +import shutil + from sklearn.metrics import accuracy_score, f1_score from sklearn.preprocessing import LabelEncoder -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup -from torch.optim import AdamW -import pandas as pd -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.multiprocessing as mp -from contextlib import contextmanager - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def initialize_wandb(config): - if config.get("use_wandb", False): - import wandb - wandb.init( - project=config.get("wandb_project", "geneformer_multitask"), - name=config.get("run_name", "experiment"), - config=config, - reinit=True, - ) +from transformers import AutoConfig, BertConfig, BertModel - -def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0): - """Create and initialize the model based on configuration.""" - from .model import GeneformerMultiTask - - model = GeneformerMultiTask( - config["pretrained_path"], - num_labels_list, - dropout_rate=config.get("dropout_rate", 0.1), - 
use_task_weights=config.get("use_task_weights", False), - task_weights=config.get("task_weights", None), - max_layers_to_freeze=config.get("max_layers_to_freeze", 0), - use_attention_pooling=config.get("use_attention_pooling", False), - ) - - # Move model to device - model.to(device) - - if is_distributed: - model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) - - return model - - -def setup_optimizer_and_scheduler(model, config, total_steps): - """Set up optimizer and learning rate scheduler.""" - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() - if not any(nd in n for nd in no_decay) and p.requires_grad], - "weight_decay": config["weight_decay"], - }, - { - "params": [p for n, p in model.named_parameters() - if any(nd in n for nd in no_decay) and p.requires_grad], - "weight_decay": 0.0, - }, - ] - - optimizer = AdamW( - optimizer_grouped_parameters, - lr=config["learning_rate"], - eps=config.get("adam_epsilon", 1e-8) - ) - - # Prepare scheduler - warmup_steps = int(total_steps * config["warmup_ratio"]) - - scheduler_map = { - "linear": get_linear_schedule_with_warmup, - "cosine": get_cosine_schedule_with_warmup - } - - scheduler_fn = scheduler_map.get(config["lr_scheduler_type"]) - if not scheduler_fn: - raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}") - - scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) - - return optimizer, scheduler +from .imports import * def save_model(model, model_save_directory): - """Save model weights and configuration.""" - os.makedirs(model_save_directory, exist_ok=True) - - # Handle DDP model - if isinstance(model, DDP): - model_to_save = model.module + if not os.path.exists(model_save_directory): + os.makedirs(model_save_directory) + + # Get the state dict + if isinstance(model, nn.DataParallel): + model_state_dict = ( + model.module.state_dict() + ) # Use model.module to access the underlying model else: - model_to_save = model - - model_state_dict = model_to_save.state_dict() + model_state_dict = model.state_dict() + + # Remove the "module." 
prefix from the keys if present + model_state_dict = { + k.replace("module.", ""): v for k, v in model_state_dict.items() + } model_save_path = os.path.join(model_save_directory, "pytorch_model.bin") torch.save(model_state_dict, model_save_path) # Save the model configuration - model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json")) - - print(f"Model and configuration saved to {model_save_directory}") - - -def save_hyperparameters(model_save_directory, hyperparams): - """Save hyperparameters to a JSON file.""" - hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") - with open(hyperparams_path, "w") as f: - json.dump(hyperparams, f) - print(f"Hyperparameters saved to {hyperparams_path}") - - -def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"): - if metric_type == "single": - # Calculate metrics for a single task - if labels is None or preds is None: - raise ValueError("Labels and predictions must be provided for single task metrics") - - task_name = None - if isinstance(labels, dict) and len(labels) == 1: - task_name = list(labels.keys())[0] - labels = labels[task_name] - preds = preds[task_name] - - f1 = f1_score(labels, preds, average="macro") - accuracy = accuracy_score(labels, preds) - - if return_format == "tuple": - return f1, accuracy - - result = {"f1": f1, "accuracy": accuracy} - if task_name: - return {task_name: result} - return result - - elif metric_type == "task_specific": - # Calculate metrics for multiple tasks - if task_data: - result = {} - for task_name, (task_labels, task_preds) in task_data.items(): - f1 = f1_score(task_labels, task_preds, average="macro") - accuracy = accuracy_score(task_labels, task_preds) - result[task_name] = {"f1": f1, "accuracy": accuracy} - return result - elif isinstance(labels, dict) and isinstance(preds, dict): - result = {} - for task_name in labels: - if task_name in preds: - f1 = f1_score(labels[task_name], preds[task_name], average="macro") - accuracy = accuracy_score(labels[task_name], preds[task_name]) - result[task_name] = {"f1": f1, "accuracy": accuracy} - return result - else: - raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided") - - elif metric_type == "combined": - # Calculate combined metrics across all tasks - if labels is None or preds is None: - raise ValueError("Labels and predictions must be provided for combined metrics") - - # Handle label encoding for non-numeric labels - if not all(isinstance(x, (int, float)) for x in labels + preds): - le = LabelEncoder() - le.fit(labels + preds) - labels = le.transform(labels) - preds = le.transform(preds) - - f1 = f1_score(labels, preds, average="macro") - accuracy = accuracy_score(labels, preds) - - if return_format == "tuple": - return f1, accuracy - return {"f1": f1, "accuracy": accuracy} - - else: - raise ValueError(f"Unknown metric_type: {metric_type}") - - -def get_layer_freeze_range(pretrained_path): - if not pretrained_path: - return {"min": 0, "max": 0} - - config = AutoConfig.from_pretrained(pretrained_path) - total_layers = config.num_hidden_layers - return {"min": 0, "max": total_layers - 1} - - -def prepare_training_environment(config): - """ - Prepare the training environment by setting seed and loading data. 
- - Returns: - tuple: (device, train_loader, val_loader, train_cell_id_mapping, - val_cell_id_mapping, num_labels_list) - """ - from .data import prepare_data_loaders - - # Set seed for reproducibility - set_seed(config["seed"]) - - # Set up device - for non-distributed training - if not config.get("distributed_training", False): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if isinstance(model, nn.DataParallel): + model.module.config.to_json_file( + os.path.join(model_save_directory, "config.json") + ) else: - # For distributed training, device will be set per process - device = None - - # Load data using the streaming dataset - data = prepare_data_loaders(config) - - # For distributed training, we'll set up samplers later in the distributed worker - # Don't create DistributedSampler here as process group isn't initialized yet - - return ( - device, - data["train_loader"], - data["val_loader"], - data["train_cell_mapping"], - data["val_cell_mapping"], - data["num_labels_list"], - ) + model.config.to_json_file(os.path.join(model_save_directory, "config.json")) + print(f"Model and configuration saved to {model_save_directory}") -# Optuna hyperparameter optimization utilities -def save_trial_callback(study, trial, trials_result_path): - """ - Callback to save Optuna trial results to a file. - - Args: - study: Optuna study object - trial: Current trial object - trials_result_path: Path to save trial results - """ - with open(trials_result_path, "a") as f: - f.write( - f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n" - ) +def calculate_task_specific_metrics(task_true_labels, task_pred_labels): + task_metrics = {} + for task_name in task_true_labels.keys(): + true_labels = task_true_labels[task_name] + pred_labels = task_pred_labels[task_name] + f1 = f1_score(true_labels, pred_labels, average="macro") + accuracy = accuracy_score(true_labels, pred_labels) + task_metrics[task_name] = {"f1": f1, "accuracy": accuracy} + return task_metrics -def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study: - """Create and run an Optuna study with TensorBoard logging.""" - from optuna.integration import TensorBoardCallback - - study = optuna.create_study(direction="maximize") - study.optimize( - objective, - n_trials=n_trials, - callbacks=[ - lambda study, trial: save_trial_callback(study, trial, trials_result_path), - TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro") - ] - ) - return study +def calculate_combined_f1(combined_labels, combined_preds): + # Initialize the LabelEncoder + le = LabelEncoder() -@contextmanager -def setup_logging(config): - run_name = config.get("run_name", "manual_run") - log_dir = os.path.join(config["tensorboard_log_dir"], run_name) - writer = SummaryWriter(log_dir=log_dir) - try: - yield writer - finally: - writer.close() + # Fit and transform combined labels and predictions to numerical values + le.fit(combined_labels + combined_preds) + encoded_true_labels = le.transform(combined_labels) + encoded_pred_labels = le.transform(combined_preds) + # Print out the mapping for sanity check + print("\nLabel Encoder Mapping:") + for index, class_label in enumerate(le.classes_): + print(f"'{class_label}': {index}") -def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx): - """Log training step metrics to TensorBoard and optionally W&B.""" - writer.add_scalar( - "Training Loss", loss, epoch * steps_per_epoch + batch_idx - 
) - if config.get("use_wandb", False): - import wandb - wandb.log({"Training Loss": loss}) + # Calculate accuracy + accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels) + # Calculate F1 Macro score + f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro") -def log_validation_metrics(task_metrics, val_loss, config, writer, epoch): - """Log validation metrics to console, TensorBoard, and optionally W&B.""" - for task_name, metrics in task_metrics.items(): - print( - f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}" - ) - if config.get("use_wandb", False): - import wandb - wandb.log( - { - f"{task_name} Validation F1 Macro": metrics["f1"], - f"{task_name} Validation Accuracy": metrics["accuracy"], - } - ) + return f1, accuracy - writer.add_scalar("Validation Loss", val_loss, epoch) - for task_name, metrics in task_metrics.items(): - writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch) - writer.add_scalar( - f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch - ) +# def save_model_without_heads(original_model_save_directory): +# # Create a new directory for the model without heads +# new_model_save_directory = original_model_save_directory + "_No_Heads" +# if not os.path.exists(new_model_save_directory): +# os.makedirs(new_model_save_directory) -def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]: - """Load or initialize task label mappings.""" - label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl") - if os.path.exists(label_mappings_path): - with open(label_mappings_path, 'rb') as f: - return pickle.load(f) - return {task_name: {} for task_name in task_names} +# # Load the model state dictionary +# model_state_dict = torch.load( +# os.path.join(original_model_save_directory, "pytorch_model.bin") +# ) +# # Initialize a new BERT model without the classification heads +# config = BertConfig.from_pretrained( +# os.path.join(original_model_save_directory, "config.json") +# ) +# model_without_heads = BertModel(config) -def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict, - task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str], - inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict: - """Create a row for validation predictions.""" - batch_cell_idx = val_cell_indices.get(sample_idx) - cell_id = val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}") if batch_cell_idx is not None else f"unknown_cell_{sample_idx}" - - row = {"Cell ID": cell_id} - for task_name in task_names: - if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]): - true_idx = task_true_labels[task_name][sample_idx] - pred_idx = task_pred_labels[task_name][sample_idx] - true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}") - pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}") - - row.update({ - f"{task_name}_true_idx": true_idx, - f"{task_name}_pred_idx": pred_idx, - f"{task_name}_true_label": true_label, - f"{task_name}_pred_label": pred_label - }) - - if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]): - probs = task_pred_probs[task_name][sample_idx] - if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)): - prob_list = list(probs) if not isinstance(probs, list) else probs - row[f"{task_name}_all_probs"] = 
",".join(map(str, prob_list)) - for class_idx, prob in enumerate(prob_list): - class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}") - row[f"{task_name}_prob_{class_label}"] = prob - else: - row[f"{task_name}_all_probs"] = str(probs) - - return row +# # Filter the state dict to exclude classification heads +# model_without_heads_state_dict = { +# k: v +# for k, v in model_state_dict.items() +# if not k.startswith("classification_heads") +# } +# # Load the filtered state dict into the model +# model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False) -def save_validation_predictions( - val_cell_indices, - task_true_labels, - task_pred_labels, - task_pred_probs, - config, - trial_number=None, -): - """Save validation predictions to a CSV file with class labels and probabilities.""" - os.makedirs(config["results_dir"], exist_ok=True) - - if trial_number is not None: - os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True) - val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv") - else: - val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv") - - if not val_cell_indices or not task_true_labels: - pd.DataFrame().to_csv(val_preds_file, index=False) - return - - try: - label_mappings = load_label_mappings(config["results_dir"], config["task_names"]) - inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()} - val_cell_mapping = config.get("val_cell_mapping", {}) - - # Determine maximum number of samples - max_samples = max( - [len(val_cell_indices)] + - [len(task_true_labels[task]) for task in task_true_labels] - ) - - rows = [ - create_prediction_row( - sample_idx, val_cell_indices, task_true_labels, task_pred_labels, - task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping - ) - for sample_idx in range(max_samples) - ] - - pd.DataFrame(rows).to_csv(val_preds_file, index=False) - except Exception as e: - pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False) +# # Save the model without heads +# model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin") +# torch.save(model_without_heads.state_dict(), model_save_path) +# # Copy the configuration file +# shutil.copy( +# os.path.join(original_model_save_directory, "config.json"), +# new_model_save_directory, +# ) -def setup_distributed_environment(rank, world_size, config): - """ - Setup the distributed training environment. - - Args: - rank (int): The rank of the current process - world_size (int): Total number of processes - config (dict): Configuration dictionary - """ - os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost') - os.environ['MASTER_PORT'] = config.get('master_port', '12355') - - # Initialize the process group - dist.init_process_group( - backend='nccl', - init_method='env://', - world_size=world_size, - rank=rank - ) - - # Set the device for this process - torch.cuda.set_device(rank) +# print(f"Model without classification heads saved to {new_model_save_directory}") -def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None): - """Run distributed training across multiple GPUs with fallback to single GPU.""" - world_size = torch.cuda.device_count() - - if world_size <= 1: - print("Distributed training requested but only one GPU found. 
Falling back to single GPU training.") - config["distributed_training"] = False - trainer = trainer_class(config) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - trainer.device = device - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - val_loss, model = trainer.train( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") - save_model(model, model_save_directory) - save_hyperparameters(model_save_directory, { - **get_config_value(config, "manual_hyperparameters", {}), - "dropout_rate": config["dropout_rate"], - "use_task_weights": config["use_task_weights"], - "task_weights": config["task_weights"], - "max_layers_to_freeze": config["max_layers_to_freeze"], - "use_attention_pooling": config["use_attention_pooling"], - }) - - if shared_dict is not None: - shared_dict['val_loss'] = val_loss - task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config) - shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()} - - return val_loss, model - - print(f"Using distributed training with {world_size} GPUs") - mp.spawn( - _distributed_worker, - args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict), - nprocs=world_size, - join=True - ) - - if trial_number is None and shared_dict is None: - model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") - model_path = os.path.join(model_save_directory, "pytorch_model.bin") - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = create_model(config, num_labels_list, device) - model.load_state_dict(torch.load(model_path)) - return 0.0, model - - return None - - -def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None): - """Worker function for distributed training.""" - setup_distributed_environment(rank, world_size, config) - config["local_rank"] = rank - - # Set up distributed samplers - from torch.utils.data import DistributedSampler - from .data import get_data_loader - - train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) - val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False) - - train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False) - val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False) - - if rank == 0: - print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples") - print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}") - - # Create and setup trainer - trainer = trainer_class(config) - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup( - 
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - - # Train the model - val_loss, model = trainer.train( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - - # Save model only from the main process - if rank == 0: - model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") - save_model(model, model_save_directory) - - save_hyperparameters(model_save_directory, { - **get_config_value(config, "manual_hyperparameters", {}), - "dropout_rate": config["dropout_rate"], - "use_task_weights": config["use_task_weights"], - "task_weights": config["task_weights"], - "max_layers_to_freeze": config["max_layers_to_freeze"], - "use_attention_pooling": config["use_attention_pooling"], - }) - - # For Optuna trials, store results in shared dictionary - if shared_dict is not None: - shared_dict['val_loss'] = val_loss - - # Run validation on full dataset from rank 0 for consistent metrics - full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False) - - # Get validation predictions using our utility function - task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions( - model, full_val_loader, trainer.device, config - ) - - # Calculate metrics - task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - shared_dict['task_metrics'] = task_metrics - - # Store model state dict - if isinstance(model, DDP): - model_state_dict = model.module.state_dict() - else: - model_state_dict = model.state_dict() - - shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()} - - # Clean up distributed environment - dist.destroy_process_group() - - -def save_model_without_heads(model_directory): +def get_layer_freeze_range(pretrained_path): """ - Save a version of the fine-tuned model without classification heads. - + Dynamically determines the number of layers to freeze based on the model depth from its configuration. Args: - model_directory (str): Path to the directory containing the fine-tuned model + pretrained_path (str): Path to the pretrained model directory or model identifier. + Returns: + dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze. 
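+
+    Example (illustrative; the path is hypothetical — a 12-layer model
+    reports num_hidden_layers == 12, so the returned range would be):
+        >>> get_layer_freeze_range("/models/Geneformer-V2-104M")
+        {'min': 0, 'max': 11}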
""" - import torch - from transformers import BertConfig, BertModel - - # Load the full model - model_path = os.path.join(model_directory, "pytorch_model.bin") - config_path = os.path.join(model_directory, "config.json") - - if not os.path.exists(model_path) or not os.path.exists(config_path): - raise FileNotFoundError(f"Model files not found in {model_directory}") - - # Load the configuration - config = BertConfig.from_json_file(config_path) - - # Load the model state dict - state_dict = torch.load(model_path, map_location=torch.device('cpu')) - - # Create a new model without heads - base_model = BertModel(config) - - # Filter out the classification head parameters - base_model_state_dict = {} - for key, value in state_dict.items(): - # Only keep parameters that belong to the base model (not classification heads) - if not key.startswith('classification_heads') and not key.startswith('attention_pool'): - base_model_state_dict[key] = value - - # Load the filtered state dict into the base model - base_model.load_state_dict(base_model_state_dict, strict=False) - - # Save the model without heads - output_dir = os.path.join(model_directory, "model_without_heads") - os.makedirs(output_dir, exist_ok=True) - - # Save the model weights - torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin")) - - # Save the configuration - base_model.config.to_json_file(os.path.join(output_dir, "config.json")) - - print(f"Model without classification heads saved to {output_dir}") - return output_dir - - -def get_config_value(config: Dict, key: str, default=None): - - return config.get(key, default) - - -def collect_validation_predictions(model, val_loader, device, config) -> tuple: - task_true_labels = {} - task_pred_labels = {} - task_pred_probs = {} - - with torch.no_grad(): - for batch in val_loader: - input_ids = batch["input_ids"].to(device) - attention_mask = batch["attention_mask"].to(device) - labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]] - _, logits, _ = model(input_ids, attention_mask, labels) - - for sample_idx in range(len(batch["input_ids"])): - for i, task_name in enumerate(config["task_names"]): - if task_name not in task_true_labels: - task_true_labels[task_name] = [] - task_pred_labels[task_name] = [] - task_pred_probs[task_name] = [] - - true_label = batch["labels"][task_name][sample_idx].item() - pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() - pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy() - - task_true_labels[task_name].append(true_label) - task_pred_labels[task_name].append(pred_label) - task_pred_probs[task_name].append(pred_prob) - - return task_true_labels, task_pred_labels, task_pred_probs + if pretrained_path: + config = AutoConfig.from_pretrained(pretrained_path) + total_layers = config.num_hidden_layers + return {"min": 0, "max": total_layers - 1} + else: + return {"min": 0, "max": 0} diff --git a/geneformer/mtl_classifier.py b/geneformer/mtl_classifier.py index d6f784c3b68d6e58a09a4ca445f62e7833705be5..68ee837a416e27d9e20156100e30718dec6778d0 100644 --- a/geneformer/mtl_classifier.py +++ b/geneformer/mtl_classifier.py @@ -29,7 +29,7 @@ Geneformer multi-task cell classifier. 
import logging import os -from .mtl import eval_utils, utils, train +from .mtl import eval_utils, train_utils, utils logger = logging.getLogger(__name__) @@ -49,9 +49,7 @@ class MTLClassifier: "max_layers_to_freeze": {None, dict}, "epochs": {None, int}, "tensorboard_log_dir": {None, str}, - "distributed_training": {None, bool}, - "master_addr": {None, str}, - "master_port": {None, str}, + "use_data_parallel": {None, bool}, "use_attention_pooling": {None, bool}, "use_task_weights": {None, bool}, "hyperparameters": {None, dict}, @@ -63,7 +61,6 @@ class MTLClassifier: "max_grad_norm": {None, int, float}, "seed": {None, int}, "trials_result_path": {None, str}, - "gradient_accumulation_steps": {None, int}, } def __init__( @@ -82,9 +79,7 @@ class MTLClassifier: max_layers_to_freeze=None, epochs=1, tensorboard_log_dir="/results/tblogdir", - distributed_training=False, - master_addr="localhost", - master_port="12355", + use_data_parallel=False, use_attention_pooling=True, use_task_weights=True, hyperparameters=None, # Default is None @@ -94,7 +89,6 @@ class MTLClassifier: wandb_project=None, gradient_clipping=False, max_grad_norm=None, - gradient_accumulation_steps=1, # Add this line with default value 1 seed=42, # Default seed value ): """ @@ -123,12 +117,8 @@ class MTLClassifier: | Path to directory to save results tensorboard_log_dir : None, str | Path to directory for Tensorboard logging results - distributed_training : None, bool - | Whether to use distributed data parallel training across multiple GPUs - master_addr : None, str - | Master address for distributed training (default: localhost) - master_port : None, str - | Master port for distributed training (default: 12355) + use_data_parallel : None, bool + | Whether to use data parallelization use_attention_pooling : None, bool | Whether to use attention pooling use_task_weights : None, bool @@ -160,8 +150,6 @@ class MTLClassifier: | Whether to use gradient clipping max_grad_norm : None, int, float | Maximum norm for gradient clipping - gradient_accumulation_steps : None, int - | Number of steps to accumulate gradients before performing a backward/update pass seed : None, int | Random seed """ @@ -177,7 +165,6 @@ class MTLClassifier: self.batch_size = batch_size self.n_trials = n_trials self.study_name = study_name - self.gradient_accumulation_steps = gradient_accumulation_steps if max_layers_to_freeze is None: # Dynamically determine the range of layers to freeze @@ -188,9 +175,7 @@ class MTLClassifier: self.epochs = epochs self.tensorboard_log_dir = tensorboard_log_dir - self.distributed_training = distributed_training - self.master_addr = master_addr - self.master_port = master_port + self.use_data_parallel = use_data_parallel self.use_attention_pooling = use_attention_pooling self.use_task_weights = use_task_weights self.hyperparameters = ( @@ -308,7 +293,7 @@ class MTLClassifier: self.config["manual_hyperparameters"] = self.manual_hyperparameters self.config["use_manual_hyperparameters"] = True - train.run_manual_tuning(self.config) + train_utils.run_manual_tuning(self.config) def validate_additional_options(self, req_var_dict): missing_variable = False @@ -345,7 +330,7 @@ class MTLClassifier: req_var_dict = dict(zip(required_variable_names, required_variables)) self.validate_additional_options(req_var_dict) - train.run_optuna_study(self.config) + train_utils.run_optuna_study(self.config) def load_and_evaluate_test_model( self, diff --git a/geneformer/perturber_utils.py b/geneformer/perturber_utils.py index 
diff --git a/geneformer/perturber_utils.py b/geneformer/perturber_utils.py
index 6970612fd36941c7bd68c04c37bb1ce2df9cab37..e7091a2f9df2e7fcb944083a3029734bce7a9328 100644
--- a/geneformer/perturber_utils.py
+++ b/geneformer/perturber_utils.py
@@ -1,6 +1,5 @@
 import itertools as it
 import logging
-import os
 import pickle
 from collections import defaultdict
 from pathlib import Path
@@ -9,7 +8,6 @@ from typing import List
 import numpy as np
 import pandas as pd
 import torch
-import datasets
 from datasets import Dataset, load_from_disk
 from peft import LoraConfig, get_peft_model
 from transformers import (
@@ -19,6 +17,11 @@ from transformers import (
     BitsAndBytesConfig,
 )
 
+GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
+TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
+ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
+
+
 logger = logging.getLogger(__name__)
 
@@ -110,25 +113,15 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
 
 # load model to GPU
 def load_model(model_type, num_classes, model_directory, mode, quantize=False):
-    if model_type == "Pretrained-Quantized":
-        inference_only = True
-        model_type = "Pretrained"
-        quantize = True
-    elif model_type == "MTLCellClassifier-Quantized":
-        inference_only = True
+    if model_type == "MTLCellClassifier-Quantized":
         model_type = "MTLCellClassifier"
         quantize = True
-    else:
-        inference_only = False
 
     output_hidden_states = (mode == "eval")
 
     # Quantization logic
-    if isinstance(quantize, dict):
-        quantize_config = quantize.get("bnb_config", None)
-        peft_config = quantize.get("peft_config", None)
-    elif quantize:
-        if inference_only:
+    if quantize:
+        if model_type == "MTLCellClassifier":
             quantize_config = BitsAndBytesConfig(load_in_8bit=True)
             peft_config = None
         else:
@@ -138,22 +131,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
                 bnb_4bit_quant_type="nf4",
                 bnb_4bit_compute_dtype=torch.bfloat16,
             )
-            try:
-                peft_config = LoraConfig(
-                    lora_alpha=128,
-                    lora_dropout=0.1,
-                    r=64,
-                    bias="none",
-                    task_type="TokenClassification",
-                )
-            except ValueError as e:
-                peft_config = LoraConfig(
-                    lora_alpha=128,
-                    lora_dropout=0.1,
-                    r=64,
-                    bias="none",
-                    task_type="TOKEN_CLS",
-                )
+            peft_config = LoraConfig(
+                lora_alpha=128,
+                lora_dropout=0.1,
+                r=64,
+                bias="none",
+                task_type="TokenClassification",
+            )
     else:
         quantize_config = None
         peft_config = None
@@ -190,34 +174,17 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
         model.eval()
 
     # Handle device placement and PEFT
-    adapter_config_path = os.path.join(model_directory, "adapter_config.json")
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if not quantize:
         # Only move non-quantized models
-        move_to_cuda(model)
-    elif os.path.exists(adapter_config_path):
-        # If adapter files exist, load them into the model using PEFT's from_pretrained
-        model = PeftModel.from_pretrained(model, model_directory)
-        move_to_cuda(model)
-        print("loading lora weights")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = model.to(device)
     elif peft_config:
-        # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
+        # Apply PEFT for quantized models (except MTLCellClassifier)
        model.enable_input_require_grads()
         model = get_peft_model(model, peft_config)
-        move_to_cuda(model)
 
     return model
 
-
-def move_to_cuda(model):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # get what device model is currently on
-    model_device = next(model.parameters()).device
-    # Check if the model is on the CPU and move to cuda if necessary
-    if (model_device.type == "cpu") and (device.type == "cuda"):
-        model.to(device)
-
-
 def quant_layers(model):
     layer_nums = []
     for name, parameter in model.named_parameters():
@@ -431,11 +398,6 @@ def remove_perturbed_indices_set(
 def make_perturbation_batch(
     example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
 ) -> tuple[Dataset, List[int]]:
-
-    # For datasets>=4.0.0, convert to dict to avoid format issues
-    if int(datasets.__version__.split(".")[0]) >= 4:
-        example_cell = example_cell[:]
-
     if combo_lvl == 0 and tokens_to_perturb == "all":
         if perturb_type in ["overexpress", "activate"]:
             range_start = 1
@@ -508,11 +470,6 @@ def make_perturbation_batch(
 def make_perturbation_batch_special(
     example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
 ) -> tuple[Dataset, List[int]]:
-
-    # For datasets>=4.0.0, convert to dict to avoid format issues
-    if int(datasets.__version__.split(".")[0]) >= 4:
-        example_cell = example_cell[:]
-
     if combo_lvl == 0 and tokens_to_perturb == "all":
         if perturb_type in ["overexpress", "activate"]:
             range_start = 1
@@ -913,4 +870,50 @@ def validate_cell_states_to_model(cell_states_to_model):
             "'goal_state': 'nf', "
             "'alt_states': ['hcm', 'other1', 'other2']}"
         )
-        raise
\ No newline at end of file
+        raise
+
+
+class GeneIdHandler:
+    def __init__(self, raise_errors=False):
+        def invert_dict(dict_obj):
+            return {v: k for k, v in dict_obj.items()}
+
+        self.raise_errors = raise_errors
+
+        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+            self.gene_token_dict = pickle.load(f)
+            self.token_gene_dict = invert_dict(self.gene_token_dict)
+
+        with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
+            self.id_gene_dict = pickle.load(f)
+            self.gene_id_dict = invert_dict(self.id_gene_dict)
+
+    def ens_to_token(self, ens_id):
+        if not self.raise_errors:
+            return self.gene_token_dict.get(ens_id, ens_id)
+        else:
+            return self.gene_token_dict[ens_id]
+
+    def token_to_ens(self, token):
+        if not self.raise_errors:
+            return self.token_gene_dict.get(token, token)
+        else:
+            return self.token_gene_dict[token]
+
+    def ens_to_symbol(self, ens_id):
+        if not self.raise_errors:
+            return self.gene_id_dict.get(ens_id, ens_id)
+        else:
+            return self.gene_id_dict[ens_id]
+
+    def symbol_to_ens(self, symbol):
+        if not self.raise_errors:
+            return self.id_gene_dict.get(symbol, symbol)
+        else:
+            return self.id_gene_dict[symbol]
+
+    def token_to_symbol(self, token):
+        return self.ens_to_symbol(self.token_to_ens(token))
+
+    def symbol_to_token(self, symbol):
+        return self.ens_to_token(self.symbol_to_ens(symbol))
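The new GeneIdHandler centralizes round-trip conversion among Ensembl IDs, token IDs, and gene symbols; with the default raise_errors=False, unknown keys pass through unchanged. A short usage sketch (the Ensembl ID is only an example; the pickled dictionaries ship with the package):

from geneformer.perturber_utils import GeneIdHandler

gih = GeneIdHandler()                        # raise_errors=False by default
token = gih.ens_to_token("ENSG00000148400")  # Ensembl ID -> token id
symbol = gih.token_to_symbol(token)          # token id -> gene symbol via Ensembl ID
# conversions round-trip, assuming a one-to-one mapping:
assert gih.symbol_to_token(symbol) == token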
diff --git a/geneformer/pretrainer.py b/geneformer/pretrainer.py
index b1af8b8b8d204b8bc6a3003037918465f4a54a92..93a47b0363f73fb06f5abc9ca4e67cf95abf0166 100644
--- a/geneformer/pretrainer.py
+++ b/geneformer/pretrainer.py
@@ -8,12 +8,13 @@ import math
 import pickle
 import warnings
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import torch
 from datasets import Dataset
 from packaging import version
+from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler
 from transformers import (
     BatchEncoding,
@@ -23,8 +24,11 @@ from transformers import (
 )
 from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
 from transformers.trainer_pt_utils import (
+    DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     LengthGroupedSampler,
 )
+from transformers.training_args import ParallelMode
 from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
 from transformers.utils.generic import _is_tensorflow, _is_torch
 
@@ -603,7 +607,7 @@ class GeneformerPretrainer(Trainer):
         )
         super().__init__(*args, **kwargs)
 
-    # updated to not use distributed sampler since Trainer now distributes with accelerate
+    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
@@ -626,15 +630,181 @@ class GeneformerPretrainer(Trainer):
                 if self.tokenizer is not None
                 else None
             )
-            return LengthGroupedSampler(
+            if self.args.world_size <= 1:
+                return LengthGroupedSampler(
                 dataset=self.train_dataset,
                 batch_size=self.args.train_batch_size,
                 lengths=lengths,
                 model_input_name=model_input_name,
                 generator=generator,
+                )
+            else:
+                return CustomDistributedLengthGroupedSampler(
+                    dataset=self.train_dataset,
+                    batch_size=self.args.train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                    lengths=lengths,
+                    model_input_name=model_input_name,
+                    seed=self.args.seed,
+                )
+
+        else:
+            if self.args.world_size <= 1:
+                if _is_torch_generator_available:
+                    return RandomSampler(self.train_dataset, generator=generator)
+                return RandomSampler(self.train_dataset)
+            elif (
+                self.args.parallel_mode
+                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+                and not self.args.dataloader_drop_last
+            ):
+                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    batch_size=self.args.per_device_train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                    seed=self.args.seed,
+                )
+            else:
+                return DistributedSampler(
+                    self.train_dataset,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                    seed=self.args.seed,
+                )
+
+
+class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
+    r"""
+    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+    length while keeping a bit of randomness.
+    """
+
+    # Copied and adapted from PyTorch DistributedSampler.
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        lengths: Optional[List[int]] = None,
+        model_input_name: Optional[str] = None,
+    ):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.num_replicas) / self.num_replicas
             )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+        self.model_input_name = (
+            model_input_name if model_input_name is not None else "input_ids"
+        )
+
+        if lengths is None:
+            print("Lengths is none - calculating lengths.")
+            if (
+                not (
+                    isinstance(dataset[0], dict)
+                    or isinstance(dataset[0], BatchEncoding)
+                )
+                or self.model_input_name not in dataset[0]
+            ):
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    f"'{self.model_input_name}' key."
+                )
+            lengths = [len(feature[self.model_input_name]) for feature in dataset]
+        self.lengths = lengths
+
+    def __iter__(self) -> Iterator:
+        # Deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[: (self.total_size - len(indices))]
         else:
-            if _is_torch_generator_available:
-                return RandomSampler(self.train_dataset, generator=generator)
-            return RandomSampler(self.train_dataset)
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+
+def get_length_grouped_indices(
+    lengths, batch_size, mega_batch_mult=None, generator=None
+):
+    """
+    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
+    similar lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [
+        indices[i : i + megabatch_size].tolist()
+        for i in range(0, len(lengths), megabatch_size)
+    ]
+    megabatches = [
+        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
+        for megabatch in megabatches
+    ]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = (
+        megabatches[max_idx][0],
+        megabatches[0][0],
+    )
+
+    return [item for sublist in megabatches for item in sublist]
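get_length_grouped_indices is pure and easy to sanity-check in isolation: indices come back in batch-size slices of similar length, with the globally longest element first. A toy run, assuming the module path shown in this diff:

import torch
from geneformer.pretrainer import get_length_grouped_indices

lengths = [5, 1, 9, 3, 7, 2, 8, 4]  # toy sequence lengths
g = torch.Generator()
g.manual_seed(0)
idx = get_length_grouped_indices(lengths, batch_size=2, generator=g)
# Consecutive pairs of indices point at similarly sized sequences, and the
# index of the longest sequence (index 2, length 9) is placed first so an
# out-of-memory error would surface on the very first batch.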
diff --git a/geneformer/token_dictionary_gc104M.pkl b/geneformer/token_dictionary_gc95M.pkl
similarity index 100%
rename from geneformer/token_dictionary_gc104M.pkl
rename to geneformer/token_dictionary_gc95M.pkl
diff --git a/geneformer/tokenizer.py b/geneformer/tokenizer.py
index 8af0cfa0f336d007feb2b144129a96c88ad8a871..3f389200a5384d6aef94be51e2b5304ff372d640 100644
--- a/geneformer/tokenizer.py
+++ b/geneformer/tokenizer.py
@@ -3,7 +3,7 @@ Geneformer tokenizer.
 
 **Input data:**
 
-| *Required format:* raw counts scRNAseq data without feature selection as .loom, .h5ad, or .zarr file.
+| *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
 | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
 | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
 
@@ -20,9 +20,9 @@ Geneformer tokenizer.
 
 **Description:**
 
-| Input data is a directory with .loom, .h5ad, or .zarr files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
+| Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
 
-| The discussion below references the .loom file format, but the analagous labels are required for .h5ad and .zarr files, just that they will be column instead of row attributes and vice versa due to the transposed format of the file types.
+| The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
 
 | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
 
@@ -30,9 +30,11 @@ Geneformer tokenizer.
 
 | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
 
-| If one's data is in other formats besides .loom, .h5ad, or .zarr, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom, .h5ad, or .zarr format prior to running the transcriptome tokenizer.
+| If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
 
-| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)
+| OF NOTE: Take care that the correct token dictionary and gene median file are used for the correct model.
+
+| OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048.
 
 """
 
@@ -46,7 +48,6 @@ Geneformer tokenizer.
 from collections import Counter
 from pathlib import Path
 from typing import Literal
 
-import anndata as ad
 import loompy as lp
 import numpy as np
 import pandas as pd
@@ -87,8 +88,6 @@ def sum_ensembl_ids(
     collapse_gene_ids,
     gene_mapping_dict,
     gene_token_dict,
-    custom_attr_name_dict,
-    use_h5ad_index,
     file_format="loom",
     chunk_size=512,
 ):
@@ -104,45 +103,33 @@ def sum_ensembl_ids(
         assert (
             "ensembl_id_collapsed" not in data.ra.keys()
         ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
-
-        assert (
-            "n_counts" in data.ca.keys()
-        ), "'n_counts' column missing from data.ca.keys()"
-
-        if custom_attr_name_dict is not None:
-            for label in custom_attr_name_dict:
-                assert label in data.ca.keys(), f"Attribute `{label}` not present in dataset features"
-
-        # Get the ensembl ids that exist in data
-        ensembl_ids = data.ra.ensembl_id
 
         # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
         # Comparing to gene_token_dict here, would not perform any mapping steps
-        if not collapse_gene_ids:
-            ensembl_id_check = [
-                gene for gene in ensembl_ids if gene in gene_token_dict.keys()
-            ]
-            if len(ensembl_id_check) == len(set(ensembl_id_check)):
+        gene_ids_in_dict = [
+            gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
+        ]
+        if collapse_gene_ids is False:
+
+            if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
                 return data_directory
             else:
                 raise ValueError("Error: data Ensembl IDs non-unique.")
-
-        # Get the genes that exist in the mapping dictionary and the value of those genes
-        genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
-        vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
-
-        # if the genes in the mapping dict and the value of those genes are of the same length,
-        # simply return the mapped values
-        if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
-            mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
-            data.ra["ensembl_id_collapsed"] = mapped_vals
+
+        gene_ids_collapsed = [
+            gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
+        ]
+        gene_ids_collapsed_in_dict = [
+            gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
+        ]
+
+        if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
+            data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
             return data_directory
-        # Genes need to be collapsed
         else:
             dedup_filename = data_directory.with_name(
                 data_directory.stem + "__dedup.loom"
             )
-            mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
-            data.ra["ensembl_id_collapsed"] = mapped_vals
+            data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
             dup_genes = [
                 idx
                 for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
@@ -201,19 +188,13 @@ def sum_ensembl_ids(
                     dsout.add_columns(processed_array, col_attrs=view.ca)
             return dedup_filename
 
-    elif file_format in ["h5ad", "zarr"]:
+    elif file_format == "h5ad":
         """
        Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found,
        sum counts together.
        Returns adata object with deduplicated Ensembl IDs.
        """
-        if file_format == "h5ad":
-            data = sc.read_h5ad(str(data_directory))
-        else:  # zarr
-            data = ad.read_zarr(str(data_directory))
-
-        if use_h5ad_index:
-            data.var["ensembl_id"] = list(data.var.index)
+        data = sc.read_h5ad(str(data_directory))
 
         assert (
             "ensembl_id" in data.var.columns
@@ -222,41 +203,33 @@ def sum_ensembl_ids(
         assert (
             "ensembl_id_collapsed" not in data.var.columns
         ), "'ensembl_id_collapsed' column already exists in data.var"
-        assert (
-            "n_counts" in data.obs.columns
-        ), "'n_counts' column missing from data.obs"
-
-        if custom_attr_name_dict is not None:
-            for label in custom_attr_name_dict:
-                assert label in data.obs.columns, f"Attribute `{label}` not present in data.obs"
-
-        # Get the ensembl ids that exist in data
-        ensembl_ids = data.var.ensembl_id
 
         # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
         # Comparing to gene_token_dict here, would not perform any mapping steps
-        if not collapse_gene_ids:
-            ensembl_id_check = [
-                gene for gene in ensembl_ids if gene in gene_token_dict.keys()
-            ]
-            if len(ensembl_id_check) == len(set(ensembl_id_check)):
+        gene_ids_in_dict = [
+            gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
+        ]
+        if collapse_gene_ids is False:
+
+            if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
                 return data
             else:
                 raise ValueError("Error: data Ensembl IDs non-unique.")
 
-        # Get the genes that exist in the mapping dictionary and the value of those genes
-        genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
-        vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
-
-        # if the genes in the mapping dict and the value of those genes are of the same length,
-        # simply return the mapped values
-        if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
-            data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
+        # Handle the case when collapse_gene_ids is True
+        gene_ids_collapsed = [
+            gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
+        ]
+        gene_ids_collapsed_in_dict = [
+            gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
+        ]
+        if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
+            data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
            return data
-        # Genes need to be collapsed
+
        else:
-            data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
-            data.var_names = data.var["ensembl_id_collapsed"]
+            data.var["ensembl_id_collapsed"] = gene_ids_collapsed
+            data.var_names = gene_ids_collapsed
             data = data[:, ~data.var.index.isna()]
             dup_genes = [
                 idx for idx, count in Counter(data.var_names).items() if count > 1
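Both the loom and h5ad branches above reduce to the same test: if the mapping collapses two token-dictionary genes onto one Ensembl ID, the duplicate rows must be summed. A toy pandas illustration of that summing step (all dictionaries and counts are made up):

import pandas as pd

gene_mapping_dict = {"ENSG001": "ENSG001", "ENSG002": "ENSG001"}  # collapses 2 -> 1

counts = pd.DataFrame(
    {"cell_0": [3, 5], "cell_1": [1, 0]},
    index=["ENSG001", "ENSG002"],  # raw Ensembl IDs
)
collapsed = counts.groupby(
    [gene_mapping_dict.get(g.upper()) for g in counts.index]
).sum()
# collapsed has a single "ENSG001" row with summed counts [8, 1]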
@@ -305,9 +278,6 @@ class TranscriptomeTokenizer:
         model_input_size=4096,
         special_token=True,
         collapse_gene_ids=True,
-        use_h5ad_index=False,
-        keep_counts=False,
-        model_version="V2",
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
         gene_mapping_file=ENSEMBL_MAPPING_FILE,
@@ -327,23 +297,15 @@ class TranscriptomeTokenizer:
         | Chunk size for anndata tokenizer.
         model_input_size : int = 4096
         | Max input size of model to truncate input to.
-        | For the V1 model series, should be 2048. For the V2 model series, should be 4096.
+        | For the 30M model series, should be 2048. For the 95M model series, should be 4096.
         special_token : bool = True
         | Adds CLS token before and EOS token after rank value encoding.
-        | For the V1 model series, should be False. For the V2 model series, should be True.
+        | For the 30M model series, should be False. For the 95M model series, should be True.
         collapse_gene_ids : bool = True
         | Whether to collapse gene IDs based on gene mapping dictionary.
-        use_h5ad_index : bool = False
-        | use index as Ensembl IDs (only available for h5ad, only if collapse_gene_ids is True)
-        keep_counts : bool = False
-        | Whether to keep a dataset column that represents gene counts normalized by total cell counts
-        | Counts will be ordered by the gene rank order within the tokenized rank value encoding for each cell.
-        model_version : str
-        | To auto-select settings for model version other than current default.
-        | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         gene_median_file : Path
         | Path to pickle file containing dictionary of non-zero median
-        | gene expression values across Genecorpus.
+        | gene expression values across Genecorpus-30M.
         token_dictionary_file : Path
         | Path to pickle file containing token dictionary (Ensembl IDs:token).
         gene_mapping_file : None, Path
@@ -365,22 +327,8 @@ class TranscriptomeTokenizer:
         # add CLS and EOS tokens
         self.special_token = special_token
 
-        # CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT
-        self.model_version = model_version
-        if self.model_version not in ["V1", "V2"]:
-            logger.error(
-                "Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells."
-            )
-        elif self.model_version == "V1":
-            self.model_input_size = 2048
-            self.special_token = False
-            from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M
-            gene_median_file = GENE_MEDIAN_FILE_30M
-            token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
-            gene_mapping_file = ENSEMBL_MAPPING_FILE_30M
-
         # load dictionary of gene normalization factors
-        # (non-zero median value of expression across Genecorpus)
+        # (non-zero median value of expression across Genecorpus-30M)
         with open(gene_median_file, "rb") as f:
             self.gene_median_dict = pickle.load(f)
 
@@ -403,18 +351,12 @@ class TranscriptomeTokenizer:
             "<eos>" in self.gene_token_dict.keys()
         ):
             logger.warning(
-                "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for V2 model series, special_token should be True."
+                "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True."
             )
 
         # if collapsing duplicate gene IDs
         self.collapse_gene_ids = collapse_gene_ids
 
-        # if using h5ad index as ensembl_ids
-        self.use_h5ad_index = use_h5ad_index
-
-        # if keeping counts within dataset column
-        self.keep_counts = keep_counts
-
         # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
         if gene_mapping_file is not None:
             with open(gene_mapping_file, "rb") as f:
@@ -439,8 +381,7 @@ class TranscriptomeTokenizer:
         data_directory: Path | str,
         output_directory: Path | str,
         output_prefix: str,
-        file_format: Literal["loom", "h5ad", "zarr"] = "loom",
-        input_identifier: str = "",
+        file_format: Literal["loom", "h5ad"] = "loom",
         use_generator: bool = False,
     ):
         """
@@ -455,21 +396,17 @@ class TranscriptomeTokenizer:
         output_prefix : str
         | Prefix for output .dataset
         file_format : str
-        | Format of input files. Can be "loom", "h5ad", or "zarr".
-        input_identifier : str
-        | Substring identifier for input .loom, .h5ad, or .zarr, only matches are tokenized
-        | Default is no identifier, tokenizes all files in provided directory.
+        | Format of input files. Can be "loom" or "h5ad".
         use_generator : bool
         | Whether to use generator or dict for tokenization.
         """
-        tokenized_cells, cell_metadata, tokenized_counts = self.tokenize_files(
-            Path(data_directory), file_format, input_identifier
+        tokenized_cells, cell_metadata = self.tokenize_files(
+            Path(data_directory), file_format
         )
 
         tokenized_dataset = self.create_dataset(
             tokenized_cells,
             cell_metadata,
-            tokenized_counts,
             use_generator=use_generator,
         )
@@ -477,10 +414,9 @@ class TranscriptomeTokenizer:
         tokenized_dataset.save_to_disk(str(output_path))
 
     def tokenize_files(
-        self, data_directory, file_format: Literal["loom", "h5ad", "zarr"] = "loom", input_identifier: str = ""
+        self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
     ):
         tokenized_cells = []
-        tokenized_counts = []
         if self.custom_attr_name_dict is not None:
             cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
             cell_metadata = {
@@ -489,20 +425,15 @@ class TranscriptomeTokenizer:
 
         # loops through directories to tokenize .loom files
         file_found = 0
-        # loops through directories to tokenize .loom, .h5ad, or .zarr files
+        # loops through directories to tokenize .loom or .h5ad files
         tokenize_file_fn = (
             self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
         )
-        if input_identifier == "":
-            file_match = f"*.{file_format}"
-        else:
-            file_match = f"*{input_identifier}*.{file_format}"
-        for file_path in data_directory.glob(file_match):
+        for file_path in data_directory.glob(f"*.{file_format}"):
             file_found = 1
             print(f"Tokenizing {file_path}")
-            file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path, file_format=file_format)
+            file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
             tokenized_cells += file_tokenized_cells
-            tokenized_counts += file_tokenized_counts
             if self.custom_attr_name_dict is not None:
                 for k in cell_attr:
                     cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
@@ -516,17 +447,15 @@ class TranscriptomeTokenizer:
                 f"No .{file_format} files found in directory {data_directory}."
             )
             raise
-        return tokenized_cells, cell_metadata, tokenized_counts
+        return tokenized_cells, cell_metadata
 
-    def tokenize_anndata(self, adata_file_path, target_sum=10_000, file_format="h5ad"):
+    def tokenize_anndata(self, adata_file_path, target_sum=10_000):
         adata = sum_ensembl_ids(
             adata_file_path,
             self.collapse_gene_ids,
             self.gene_mapping_dict,
             self.gene_token_dict,
-            self.custom_attr_name_dict,
-            self.use_h5ad_index,
-            file_format=file_format,
+            file_format="h5ad",
             chunk_size=self.chunk_size,
         )
 
@@ -565,7 +494,6 @@ class TranscriptomeTokenizer:
             filter_pass_loc = np.array([i for i in range(adata.shape[0])])
 
         tokenized_cells = []
-        tokenized_counts = []
 
         for i in range(0, len(filter_pass_loc), self.chunk_size):
             idx = filter_pass_loc[i : i + self.chunk_size]
 
             n_counts = adata[idx].obs["n_counts"].values[:, None]
             X_view0 = adata[idx, :].X
             X_view = X_view0[:, coding_miRNA_loc]
-            X_norm_unscaled = X_view / n_counts * target_sum
-            X_norm = X_norm_unscaled / norm_factor_vector
+            X_norm = X_view / n_counts * target_sum / norm_factor_vector
             X_norm = sp.csr_matrix(X_norm)
-            X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
 
             tokenized_cells += [
                 rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
                 for i in range(X_norm.shape[0])
             ]
 
-            if self.keep_counts:
-                X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
-                tokenized_counts += [
-                    rank_genes(X_norm[i].data, X_norm_unscaled[i].data)
-                    for i in range(X_norm.shape[0])
-                ]
-
             # add custom attributes for subview to dict
             if self.custom_attr_name_dict is not None:
                 for k in file_cell_metadata.keys():
@@ -597,28 +516,9 @@ class TranscriptomeTokenizer:
         else:
             file_cell_metadata = None
 
-        # ensure no tokenized_cells are empty
-        empty_cell_indices = [i for i, cell in enumerate(tokenized_cells) if cell.size == 0]
-        if len(empty_cell_indices) > 0:
-            logger.warning(
-                "Warning: cells without any genes in token dictionary detected. This is unusual and may indicate empty droplets or otherwise invalid cells within the input data. Consider further QC prior to tokenization. Proceeding with excluding empty cells."
-            )
-            empty_cell_indices.sort(reverse=True)  # for safe deletion
-            for index in empty_cell_indices:
-                del tokenized_cells[index]
-                if self.keep_counts:
-                    del tokenized_counts[index]
-            # remove corresponding metadata
-            for k, v in file_cell_metadata.items():
-                for index in empty_cell_indices:
-                    del v[index]
-                file_cell_metadata[k] = v
-
-        return tokenized_cells, file_cell_metadata, tokenized_counts
+        return tokenized_cells, file_cell_metadata
 
-    def tokenize_loom(self, loom_file_path, target_sum=10_000, file_format="loom"):
-        tokenized_counts = []  # keep_counts not implemented for tokenize_loom
-
+    def tokenize_loom(self, loom_file_path, target_sum=10_000):
         if self.custom_attr_name_dict is not None:
             file_cell_metadata = {
                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
@@ -631,9 +531,7 @@ class TranscriptomeTokenizer:
             self.collapse_gene_ids,
             self.gene_mapping_dict,
             self.gene_token_dict,
-            self.custom_attr_name_dict,
-            use_h5ad_index=False,
-            file_format=file_format,
+            file_format="loom",
             chunk_size=self.chunk_size,
         )
 
@@ -706,33 +604,29 @@ class TranscriptomeTokenizer:
 
         del data.ra["ensembl_id_collapsed"]
 
-        return tokenized_cells, file_cell_metadata, tokenized_counts
+        return tokenized_cells, file_cell_metadata
 
     def create_dataset(
         self,
         tokenized_cells,
         cell_metadata,
-        tokenized_counts,
         use_generator=False,
         keep_uncropped_input_ids=False,
     ):
         print("Creating dataset.")
         # create dict for dataset creation
         dataset_dict = {"input_ids": tokenized_cells}
-        if self.keep_counts:
-            dataset_dict["counts"] = tokenized_counts
-
         if self.custom_attr_name_dict is not None:
             dataset_dict.update(cell_metadata)
 
         # create dataset
         if use_generator:
+
             def dict_generator():
                 for i in range(len(tokenized_cells)):
                     yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
 
             output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
-
         else:
             output_dataset = Dataset.from_dict(dataset_dict)
 
@@ -755,23 +649,9 @@ class TranscriptomeTokenizer:
                     len(example["input_ids"]),
                     self.gene_token_dict.get("<eos>"),
                 )
-                if self.keep_counts:
-                    example["counts"] = example["counts"][
-                        0 : self.model_input_size - 2
-                    ]  # truncate to leave space for CLS and EOS token
-                    example["counts"] = np.insert(
-                        example["counts"], 0, 0.0
-                    )
-                    example["counts"] = np.insert(
-                        example["counts"],
-                        len(example["counts"]),
-                        0.0,
-                    )
             else:
                 # Truncate/Crop input_ids to input size
                 example["input_ids"] = example["input_ids"][0 : self.model_input_size]
-                if self.keep_counts:
-                    example["counts"] = example["counts"][0 : self.model_input_size]
 
             example["length"] = len(example["input_ids"])
 
             return example
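The single-line normalization restored above scales each cell to target_sum, divides by per-gene nonzero-median factors, and then ranks genes by the normalized value. A numeric toy version (all values made up):

import numpy as np

X_view = np.array([[4.0, 1.0, 6.0]])  # raw counts: one cell, three genes
n_counts = np.array([[10.0]])         # total counts in that cell
norm_factor_vector = np.array([2.0, 1.0, 4.0])
target_sum = 10_000

X_norm = X_view / n_counts * target_sum / norm_factor_vector
# X_norm == [[2000., 1000., 1500.]]
rank_value_encoding = np.argsort(-X_norm[0])  # -> [0, 2, 1]: highest gene first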
"num_attention_heads": 8, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 25426 +} diff --git a/gf-12L-30M-i2048/pytorch_model.bin b/gf-12L-30M-i2048/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..d706ef2ff77fd6809a91034e6ed24af0e1b33999 --- /dev/null +++ b/gf-12L-30M-i2048/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615 +size 158467410 diff --git a/gf-12L-30M-i2048/training_args.bin b/gf-12L-30M-i2048/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..346383caaaa3b555cb6fcd8de8e4982ebf2a50d5 --- /dev/null +++ b/gf-12L-30M-i2048/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e +size 2607 diff --git a/Geneformer-V2-316M/config.json b/gf-12L-95M-i4096/config.json similarity index 63% rename from Geneformer-V2-316M/config.json rename to gf-12L-95M-i4096/config.json index 6bc648aa565eabb748a6a43ee4def5032a0d5237..86e20c35e6f257f0daeb00ebb92a0751d12d8fff 100755 --- a/Geneformer-V2-316M/config.json +++ b/gf-12L-95M-i4096/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 1152, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 4608, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 18, - "num_hidden_layers": 18, + "num_attention_heads": 8, + "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.44.2", + "transformers_version": "4.37.1", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/Geneformer-V2-104M_CLcancer/generation_config.json b/gf-12L-95M-i4096/generation_config.json similarity index 100% rename from Geneformer-V2-104M_CLcancer/generation_config.json rename to gf-12L-95M-i4096/generation_config.json diff --git a/gf-12L-95M-i4096/model.safetensors b/gf-12L-95M-i4096/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..1069352219a29bed65fa8e13feb77004128174fa --- /dev/null +++ b/gf-12L-95M-i4096/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c +size 152012980 diff --git a/gf-12L-95M-i4096/training_args.bin b/gf-12L-95M-i4096/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..18802f485a03e0262866d1ef7a3e4748a3b14ed3 --- /dev/null +++ b/gf-12L-95M-i4096/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d +size 4920 diff --git a/Geneformer-V2-104M_CLcancer/config.json b/gf-12L-95M-i4096_CLcancer/config.json similarity index 57% rename from Geneformer-V2-104M_CLcancer/config.json rename to gf-12L-95M-i4096_CLcancer/config.json index 8af4fee97b1880ff66cb6a183e824da7318c1416..a7793eb2ea27b28f1f4c5b9974d30c98b4afe8a6 100755 --- a/Geneformer-V2-104M_CLcancer/config.json +++ 
b/gf-12L-95M-i4096_CLcancer/config.json @@ -1,19 +1,19 @@ { - "_name_or_path": "/gladstone/theodoris/lab/ctheodoris/gf-104m/models/241127_143148_geneformer_94M_L12_emb768_SL4096_E3_B18_LR0.0002_LScosine_WR0.007_Oadamw_DS13/models", + "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models", "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 3072, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 12, + "num_attention_heads": 8, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", diff --git a/Geneformer-V2-316M/generation_config.json b/gf-12L-95M-i4096_CLcancer/generation_config.json similarity index 61% rename from Geneformer-V2-316M/generation_config.json rename to gf-12L-95M-i4096_CLcancer/generation_config.json index 0786a9f4dc0de68ee18cbf78399931b05fbefee7..6f690c1f39b5b262e6b898b8891afd9d44978f11 100755 --- a/Geneformer-V2-316M/generation_config.json +++ b/gf-12L-95M-i4096_CLcancer/generation_config.json @@ -1,5 +1,5 @@ { "_from_model_config": true, "pad_token_id": 0, - "transformers_version": "4.44.2" + "transformers_version": "4.37.1" } diff --git a/gf-12L-95M-i4096_CLcancer/model.safetensors b/gf-12L-95M-i4096_CLcancer/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..cc620ee4b4243b7ab6d83ad518563e1425eab45b --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2 +size 152012980 diff --git a/gf-12L-95M-i4096_CLcancer/training_args.bin b/gf-12L-95M-i4096_CLcancer/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..1669f5848710ca4a53db6e118e50b816f85381b7 --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1 +size 5048 diff --git a/gf-20L-95M-i4096/config.json b/gf-20L-95M-i4096/config.json new file mode 100755 index 0000000000000000000000000000000000000000..db949ba1ae442ad3b9e52fd8b7922c6b936ef98c --- /dev/null +++ b/gf-20L-95M-i4096/config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "classifier_dropout": null, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 896, + "initializer_range": 0.02, + "intermediate_size": 1792, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 4096, + "model_type": "bert", + "num_attention_heads": 14, + "num_hidden_layers": 20, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.37.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 20275 +} diff --git a/Geneformer-V2-104M/generation_config.json b/gf-20L-95M-i4096/generation_config.json similarity index 61% rename from Geneformer-V2-104M/generation_config.json rename to gf-20L-95M-i4096/generation_config.json index 
0786a9f4dc0de68ee18cbf78399931b05fbefee7..6f690c1f39b5b262e6b898b8891afd9d44978f11 100755 --- a/Geneformer-V2-104M/generation_config.json +++ b/gf-20L-95M-i4096/generation_config.json @@ -1,5 +1,5 @@ { "_from_model_config": true, "pad_token_id": 0, - "transformers_version": "4.44.2" + "transformers_version": "4.37.1" } diff --git a/gf-20L-95M-i4096/model.safetensors b/gf-20L-95M-i4096/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..37212863afb501a17425dd48766d71d534537d24 --- /dev/null +++ b/gf-20L-95M-i4096/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf +size 605292732 diff --git a/gf-20L-95M-i4096/training_args.bin b/gf-20L-95M-i4096/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..3db61b0b99d299afb7c4a237d2b531baa253e5d3 --- /dev/null +++ b/gf-20L-95M-i4096/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc +size 5048 diff --git a/Geneformer-V1-10M/config.json b/gf-6L-30M-i2048/config.json similarity index 100% rename from Geneformer-V1-10M/config.json rename to gf-6L-30M-i2048/config.json diff --git a/Geneformer-V1-10M/model.safetensors b/gf-6L-30M-i2048/model.safetensors similarity index 100% rename from Geneformer-V1-10M/model.safetensors rename to gf-6L-30M-i2048/model.safetensors diff --git a/Geneformer-V1-10M/pytorch_model.bin b/gf-6L-30M-i2048/pytorch_model.bin similarity index 100% rename from Geneformer-V1-10M/pytorch_model.bin rename to gf-6L-30M-i2048/pytorch_model.bin diff --git a/Geneformer-V1-10M/training_args.bin b/gf-6L-30M-i2048/training_args.bin similarity index 100% rename from Geneformer-V1-10M/training_args.bin rename to gf-6L-30M-i2048/training_args.bin diff --git a/model.safetensors b/model.safetensors index f9c0c25c4a1df80ddc1455def715a9856152882c..1069352219a29bed65fa8e13feb77004128174fa 100644 --- a/model.safetensors +++ b/model.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3 -size 1265455076 +oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c +size 152012980 diff --git a/requirements.txt b/requirements.txt index 4b9b25e28193489e3416cc324eefe6db708bc004..be1e0462b733623b0c31dd4fea5e27e5660c23ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ anndata>=0.9 -bitsandbytes>=0.45.5 datasets>=2.12 hyperopt>=0.2 loompy>=3.0 @@ -23,4 +22,4 @@ tdigest>=0.5.2 tensorboard>=2.15 torch>=2.0.1 tqdm>=4.65 -transformers==4.46 \ No newline at end of file +transformers>=4.28 diff --git a/setup.py b/setup.py index 81d947cea5933da94d1ea315dd5eee59f18c6577..6dde9eefad8c76e3d1e41ae187f2215bdbc93db5 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ setup( include_package_data=True, install_requires=[ "anndata", - "bitsandbytes", "datasets", "loompy", "matplotlib", diff --git a/training_args.bin b/training_args.bin index 630bab56b199325b337a9969d30167f5b73b7815..18802f485a03e0262866d1ef7a3e4748a3b14ed3 100644 --- a/training_args.bin +++ b/training_args.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e45150f9a4ca34cb4e91ce79f65f3d99d9d66df9f66a37517a352d291008e0b8 -size 5432 +oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d +size 4920
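Since every config above declares a plain BertForMaskedLM, any of the renamed checkpoint directories loads directly with transformers. A sketch assuming a local clone of this repository:

from transformers import BertForMaskedLM

# Directory name taken from the diff; adjust the path to your checkout.
model = BertForMaskedLM.from_pretrained("gf-12L-95M-i4096")
print(model.config.num_hidden_layers, model.config.hidden_size)  # 12 512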